bumchik2 commited on
Commit
5258578
·
1 Parent(s): c7f1481

adding notebooks

Browse files
notebooks/albert_base_v2_main.ipynb ADDED
@@ -0,0 +1,1180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 27,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stdout",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "4.48.3\n"
13
+ ]
14
+ },
15
+ {
16
+ "data": {
17
+ "text/plain": [
18
+ "device(type='cuda')"
19
+ ]
20
+ },
21
+ "execution_count": 27,
22
+ "metadata": {},
23
+ "output_type": "execute_result"
24
+ }
25
+ ],
26
+ "source": [
27
+ "from transformers import pipeline\n",
28
+ "import json\n",
29
+ "import pandas as pd\n",
30
+ "from sklearn.model_selection import train_test_split\n",
31
+ "from transformers import AlbertTokenizer\n",
32
+ "from tqdm import tqdm\n",
33
+ "import re\n",
34
+ "from datasets import Dataset\n",
35
+ "from transformers import AutoModelForSequenceClassification\n",
36
+ "import torch\n",
37
+ "import numpy as np\n",
38
+ "from typing import Dict\n",
39
+ "from transformers import AutoModel\n",
40
+ "from torch.nn import BCEWithLogitsLoss\n",
41
+ "from typing import List\n",
42
+ "from transformers import TrainingArguments, Trainer\n",
43
+ "from collections import defaultdict\n",
44
+ "\n",
45
+ "from transformers import __version__ as transformers_version\n",
46
+ "print(transformers_version)\n",
47
+ "\n",
48
+ "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
49
+ "DEVICE"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": 2,
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "USED_MODEL = \"albert-base-v2\""
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": 3,
64
+ "metadata": {},
65
+ "outputs": [],
66
+ "source": [
67
+ "def read_json(json_filename):\n",
68
+ " with open(json_filename, 'r') as f:\n",
69
+ " return json.loads(f.read())\n",
70
+ "\n",
71
+ "\n",
72
+ "def save_json(json_object, json_filename, indent=4):\n",
73
+ " with open(json_filename, 'w') as f:\n",
74
+ " json.dump(json_object, f, separators=(',', ':'), indent=indent)"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "markdown",
79
+ "metadata": {},
80
+ "source": [
81
+ "**Данные берем отсюда: https://www.kaggle.com/datasets/neelshah18/arxivdataset**"
82
+ ]
83
+ },
84
+ {
85
+ "cell_type": "code",
86
+ "execution_count": 4,
87
+ "metadata": {},
88
+ "outputs": [],
89
+ "source": [
90
+ "arxiv_data = read_json('arxivData.json')"
91
+ ]
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "execution_count": 5,
96
+ "metadata": {},
97
+ "outputs": [
98
+ {
99
+ "data": {
100
+ "text/plain": [
101
+ "{'author': \"[{'name': 'Ahmed Osman'}, {'name': 'Wojciech Samek'}]\",\n",
102
+ " 'day': 1,\n",
103
+ " 'id': '1802.00209v1',\n",
104
+ " 'link': \"[{'rel': 'alternate', 'href': 'http://arxiv.org/abs/1802.00209v1', 'type': 'text/html'}, {'rel': 'related', 'href': 'http://arxiv.org/pdf/1802.00209v1', 'type': 'application/pdf', 'title': 'pdf'}]\",\n",
105
+ " 'month': 2,\n",
106
+ " 'summary': 'We propose an architecture for VQA which utilizes recurrent layers to\\ngenerate visual and textual attention. The memory characteristic of the\\nproposed recurrent attention units offers a rich joint embedding of visual and\\ntextual features and enables the model to reason relations between several\\nparts of the image and question. Our single model outperforms the first place\\nwinner on the VQA 1.0 dataset, performs within margin to the current\\nstate-of-the-art ensemble model. We also experiment with replacing attention\\nmechanisms in other state-of-the-art models with our implementation and show\\nincreased accuracy. In both cases, our recurrent attention mechanism improves\\nperformance in tasks requiring sequential or relational reasoning on the VQA\\ndataset.',\n",
107
+ " 'tag': \"[{'term': 'cs.AI', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'cs.CL', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'cs.CV', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'cs.NE', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'stat.ML', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}]\",\n",
108
+ " 'title': 'Dual Recurrent Attention Units for Visual Question Answering',\n",
109
+ " 'year': 2018}"
110
+ ]
111
+ },
112
+ "execution_count": 5,
113
+ "metadata": {},
114
+ "output_type": "execute_result"
115
+ }
116
+ ],
117
+ "source": [
118
+ "arxiv_data[0]"
119
+ ]
120
+ },
121
+ {
122
+ "cell_type": "markdown",
123
+ "metadata": {},
124
+ "source": [
125
+ "**Хотим по названию статьи + abstract выдавать наиболее вероятную тематику статьи, скажем, физика, биология или computer science** "
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": 6,
131
+ "metadata": {},
132
+ "outputs": [
133
+ {
134
+ "name": "stdout",
135
+ "output_type": "stream",
136
+ "text": [
137
+ "155\n"
138
+ ]
139
+ },
140
+ {
141
+ "data": {
142
+ "text/html": [
143
+ "<div>\n",
144
+ "<style scoped>\n",
145
+ " .dataframe tbody tr th:only-of-type {\n",
146
+ " vertical-align: middle;\n",
147
+ " }\n",
148
+ "\n",
149
+ " .dataframe tbody tr th {\n",
150
+ " vertical-align: top;\n",
151
+ " }\n",
152
+ "\n",
153
+ " .dataframe thead th {\n",
154
+ " text-align: right;\n",
155
+ " }\n",
156
+ "</style>\n",
157
+ "<table border=\"1\" class=\"dataframe\">\n",
158
+ " <thead>\n",
159
+ " <tr style=\"text-align: right;\">\n",
160
+ " <th></th>\n",
161
+ " <th>tag</th>\n",
162
+ " <th>topic</th>\n",
163
+ " <th>category</th>\n",
164
+ " </tr>\n",
165
+ " </thead>\n",
166
+ " <tbody>\n",
167
+ " <tr>\n",
168
+ " <th>0</th>\n",
169
+ " <td>cs.AI</td>\n",
170
+ " <td>Artificial Intelligence</td>\n",
171
+ " <td>Computer Science</td>\n",
172
+ " </tr>\n",
173
+ " <tr>\n",
174
+ " <th>1</th>\n",
175
+ " <td>cs.AR</td>\n",
176
+ " <td>Hardware Architecture</td>\n",
177
+ " <td>Computer Science</td>\n",
178
+ " </tr>\n",
179
+ " <tr>\n",
180
+ " <th>2</th>\n",
181
+ " <td>cs.CC</td>\n",
182
+ " <td>Computational Complexity</td>\n",
183
+ " <td>Computer Science</td>\n",
184
+ " </tr>\n",
185
+ " <tr>\n",
186
+ " <th>3</th>\n",
187
+ " <td>cs.CE</td>\n",
188
+ " <td>Computational Engineering, Finance, and Science</td>\n",
189
+ " <td>Computer Science</td>\n",
190
+ " </tr>\n",
191
+ " <tr>\n",
192
+ " <th>4</th>\n",
193
+ " <td>cs.CG</td>\n",
194
+ " <td>Computational Geometry</td>\n",
195
+ " <td>Computer Science</td>\n",
196
+ " </tr>\n",
197
+ " </tbody>\n",
198
+ "</table>\n",
199
+ "</div>"
200
+ ],
201
+ "text/plain": [
202
+ " tag topic category\n",
203
+ "0 cs.AI Artificial Intelligence Computer Science\n",
204
+ "1 cs.AR Hardware Architecture Computer Science\n",
205
+ "2 cs.CC Computational Complexity Computer Science\n",
206
+ "3 cs.CE Computational Engineering, Finance, and Science Computer Science\n",
207
+ "4 cs.CG Computational Geometry Computer Science"
208
+ ]
209
+ },
210
+ "execution_count": 6,
211
+ "metadata": {},
212
+ "output_type": "execute_result"
213
+ }
214
+ ],
215
+ "source": [
216
+ "arxiv_topics_df = pd.read_csv('arxiv_topics.csv')\n",
217
+ "print(len(arxiv_topics_df))\n",
218
+ "arxiv_topics_df.head(5)"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": 7,
224
+ "metadata": {},
225
+ "outputs": [],
226
+ "source": [
227
+ "category_to_index = {}\n",
228
+ "tag_to_category = {}\n",
229
+ "current_index = 0\n",
230
+ "for i, row in arxiv_topics_df.iterrows():\n",
231
+ " category = row['category']\n",
232
+ " if category not in category_to_index:\n",
233
+ " category_to_index[category] = current_index\n",
234
+ " current_index += 1\n",
235
+ " tag_to_category[row['tag']] = row['category']\n",
236
+ "index_to_category = {value: key for key, value in category_to_index.items()}"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "markdown",
241
+ "metadata": {},
242
+ "source": [
243
+ "**Готовим данные к обучению**"
244
+ ]
245
+ },
246
+ {
247
+ "cell_type": "code",
248
+ "execution_count": 8,
249
+ "metadata": {},
250
+ "outputs": [
251
+ {
252
+ "name": "stderr",
253
+ "output_type": "stream",
254
+ "text": [
255
+ " 0%| | 0/41000 [00:00<?, ?it/s]"
256
+ ]
257
+ },
258
+ {
259
+ "name": "stderr",
260
+ "output_type": "stream",
261
+ "text": [
262
+ "100%|██████████| 41000/41000 [00:01<00:00, 33667.39it/s]"
263
+ ]
264
+ },
265
+ {
266
+ "name": "stdout",
267
+ "output_type": "stream",
268
+ "text": [
269
+ "Среднее число категорий в одной статье: 1.3301219512195122\n",
270
+ "Среднее число тегов в одной статье: 1.8489024390243902\n"
271
+ ]
272
+ },
273
+ {
274
+ "name": "stderr",
275
+ "output_type": "stream",
276
+ "text": [
277
+ "\n"
278
+ ]
279
+ }
280
+ ],
281
+ "source": [
282
+ "def is_valid_tag(tag: str) -> bool:\n",
283
+ " return tag in tag_to_category\n",
284
+ "\n",
285
+ "total_categories_count = 0\n",
286
+ "total_tags_count = 0\n",
287
+ "records = []\n",
288
+ "for arxiv_record in tqdm(arxiv_data):\n",
289
+ " record = {\n",
290
+ " 'title': arxiv_record['title'],\n",
291
+ " 'summary': arxiv_record['summary'],\n",
292
+ " 'title_and_summary': arxiv_record['title'] + ' $ ' + arxiv_record['summary'],\n",
293
+ " 'tags': [current_tag['term'] for current_tag in eval(arxiv_record['tag']) if is_valid_tag(current_tag['term'])]\n",
294
+ " }\n",
295
+ " categories = set(tag_to_category[tag] for tag in record['tags'])\n",
296
+ " total_categories_count += len(categories)\n",
297
+ " total_tags_count += len(record['tags'])\n",
298
+ " record['categories_indices'] = list(set([category_to_index[tag_to_category[tag]] for tag in record['tags']]))\n",
299
+ " assert len(record['tags']) > 0\n",
300
+ " records.append(record)\n",
301
+ "\n",
302
+ "print(f'Среднее число категорий в одной статье: {total_categories_count / len(arxiv_data)}')\n",
303
+ "print(f'Среднее число тегов в одной статье: {total_tags_count / len(arxiv_data)}')"
304
+ ]
305
+ },
306
+ {
307
+ "cell_type": "markdown",
308
+ "metadata": {},
309
+ "source": [
310
+ "Как видим, перед нами задача мультибинарной классификации.\n",
311
+ "\n",
312
+ "Тегов у одной статьи бывает много, это понятно, но и категорий тоже бывает много. То есть, условно статья может быть посвящена и физике и биологии одновременно.\n",
313
+ "\n",
314
+ "Попробуем обучить модель определять теги - так она потенциально может сохранить в себе больше информации, чем если ее обучить определять категории (которых гораздо меньше)."
315
+ ]
316
+ },
317
+ {
318
+ "cell_type": "markdown",
319
+ "metadata": {},
320
+ "source": [
321
+ "**Соединяем title и summary используя символ `$` - он редкий, при этом его знает токенайзер, поэтому не придется с ним дополнительно возиться**"
322
+ ]
323
+ },
324
+ {
325
+ "cell_type": "code",
326
+ "execution_count": 9,
327
+ "metadata": {},
328
+ "outputs": [
329
+ {
330
+ "name": "stdout",
331
+ "output_type": "stream",
332
+ "text": [
333
+ "41000\n"
334
+ ]
335
+ },
336
+ {
337
+ "data": {
338
+ "text/html": [
339
+ "<div>\n",
340
+ "<style scoped>\n",
341
+ " .dataframe tbody tr th:only-of-type {\n",
342
+ " vertical-align: middle;\n",
343
+ " }\n",
344
+ "\n",
345
+ " .dataframe tbody tr th {\n",
346
+ " vertical-align: top;\n",
347
+ " }\n",
348
+ "\n",
349
+ " .dataframe thead th {\n",
350
+ " text-align: right;\n",
351
+ " }\n",
352
+ "</style>\n",
353
+ "<table border=\"1\" class=\"dataframe\">\n",
354
+ " <thead>\n",
355
+ " <tr style=\"text-align: right;\">\n",
356
+ " <th></th>\n",
357
+ " <th>title</th>\n",
358
+ " <th>summary</th>\n",
359
+ " <th>title_and_summary</th>\n",
360
+ " <th>tags</th>\n",
361
+ " <th>categories_indices</th>\n",
362
+ " </tr>\n",
363
+ " </thead>\n",
364
+ " <tbody>\n",
365
+ " <tr>\n",
366
+ " <th>0</th>\n",
367
+ " <td>Dual Recurrent Attention Units for Visual Ques...</td>\n",
368
+ " <td>We propose an architecture for VQA which utili...</td>\n",
369
+ " <td>Dual Recurrent Attention Units for Visual Ques...</td>\n",
370
+ " <td>[cs.AI, cs.CL, cs.CV, cs.NE, stat.ML]</td>\n",
371
+ " <td>[0, 7]</td>\n",
372
+ " </tr>\n",
373
+ " <tr>\n",
374
+ " <th>1</th>\n",
375
+ " <td>Sequential Short-Text Classification with Recu...</td>\n",
376
+ " <td>Recent approaches based on artificial neural n...</td>\n",
377
+ " <td>Sequential Short-Text Classification with Recu...</td>\n",
378
+ " <td>[cs.CL, cs.AI, cs.LG, cs.NE, stat.ML]</td>\n",
379
+ " <td>[0, 7]</td>\n",
380
+ " </tr>\n",
381
+ " <tr>\n",
382
+ " <th>2</th>\n",
383
+ " <td>Multiresolution Recurrent Neural Networks: An ...</td>\n",
384
+ " <td>We introduce the multiresolution recurrent neu...</td>\n",
385
+ " <td>Multiresolution Recurrent Neural Networks: An ...</td>\n",
386
+ " <td>[cs.CL, cs.AI, cs.LG, cs.NE, stat.ML]</td>\n",
387
+ " <td>[0, 7]</td>\n",
388
+ " </tr>\n",
389
+ " <tr>\n",
390
+ " <th>3</th>\n",
391
+ " <td>Learning what to share between loosely related...</td>\n",
392
+ " <td>Multi-task learning is motivated by the observ...</td>\n",
393
+ " <td>Learning what to share between loosely related...</td>\n",
394
+ " <td>[stat.ML, cs.AI, cs.CL, cs.LG, cs.NE]</td>\n",
395
+ " <td>[0, 7]</td>\n",
396
+ " </tr>\n",
397
+ " <tr>\n",
398
+ " <th>4</th>\n",
399
+ " <td>A Deep Reinforcement Learning Chatbot</td>\n",
400
+ " <td>We present MILABOT: a deep reinforcement learn...</td>\n",
401
+ " <td>A Deep Reinforcement Learning Chatbot $ We pre...</td>\n",
402
+ " <td>[cs.CL, cs.AI, cs.LG, cs.NE, stat.ML]</td>\n",
403
+ " <td>[0, 7]</td>\n",
404
+ " </tr>\n",
405
+ " </tbody>\n",
406
+ "</table>\n",
407
+ "</div>"
408
+ ],
409
+ "text/plain": [
410
+ " title \\\n",
411
+ "0 Dual Recurrent Attention Units for Visual Ques... \n",
412
+ "1 Sequential Short-Text Classification with Recu... \n",
413
+ "2 Multiresolution Recurrent Neural Networks: An ... \n",
414
+ "3 Learning what to share between loosely related... \n",
415
+ "4 A Deep Reinforcement Learning Chatbot \n",
416
+ "\n",
417
+ " summary \\\n",
418
+ "0 We propose an architecture for VQA which utili... \n",
419
+ "1 Recent approaches based on artificial neural n... \n",
420
+ "2 We introduce the multiresolution recurrent neu... \n",
421
+ "3 Multi-task learning is motivated by the observ... \n",
422
+ "4 We present MILABOT: a deep reinforcement learn... \n",
423
+ "\n",
424
+ " title_and_summary \\\n",
425
+ "0 Dual Recurrent Attention Units for Visual Ques... \n",
426
+ "1 Sequential Short-Text Classification with Recu... \n",
427
+ "2 Multiresolution Recurrent Neural Networks: An ... \n",
428
+ "3 Learning what to share between loosely related... \n",
429
+ "4 A Deep Reinforcement Learning Chatbot $ We pre... \n",
430
+ "\n",
431
+ " tags categories_indices \n",
432
+ "0 [cs.AI, cs.CL, cs.CV, cs.NE, stat.ML] [0, 7] \n",
433
+ "1 [cs.CL, cs.AI, cs.LG, cs.NE, stat.ML] [0, 7] \n",
434
+ "2 [cs.CL, cs.AI, cs.LG, cs.NE, stat.ML] [0, 7] \n",
435
+ "3 [stat.ML, cs.AI, cs.CL, cs.LG, cs.NE] [0, 7] \n",
436
+ "4 [cs.CL, cs.AI, cs.LG, cs.NE, stat.ML] [0, 7] "
437
+ ]
438
+ },
439
+ "execution_count": 9,
440
+ "metadata": {},
441
+ "output_type": "execute_result"
442
+ }
443
+ ],
444
+ "source": [
445
+ "full_data_df = pd.DataFrame(records)\n",
446
+ "print(len(full_data_df))\n",
447
+ "full_data_df.head(5)"
448
+ ]
449
+ },
450
+ {
451
+ "cell_type": "markdown",
452
+ "metadata": {},
453
+ "source": [
454
+ "**Как видим, Computer science встречается очень часто. А, например, экономика - совсем редко. Значит при обучении экономике логично давать больше вес**"
455
+ ]
456
+ },
457
+ {
458
+ "cell_type": "code",
459
+ "execution_count": 10,
460
+ "metadata": {},
461
+ "outputs": [],
462
+ "source": [
463
+ "text_data = list(full_data_df['title_and_summary'])\n",
464
+ "categories_indices = list(full_data_df['categories_indices'])"
465
+ ]
466
+ },
467
+ {
468
+ "cell_type": "code",
469
+ "execution_count": 11,
470
+ "metadata": {},
471
+ "outputs": [
472
+ {
473
+ "name": "stdout",
474
+ "output_type": "stream",
475
+ "text": [
476
+ "28700 8200 4100\n"
477
+ ]
478
+ }
479
+ ],
480
+ "source": [
481
+ "X_train_val, X_test, y_train_val, y_test = train_test_split(text_data, categories_indices, test_size=0.1, random_state=42)\n",
482
+ "X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=2/9, random_state=42)\n",
483
+ "print(len(X_train), len(X_val), len(X_test))\n",
484
+ "# Train is 70%, val is 20%, test is 10%"
485
+ ]
486
+ },
487
+ {
488
+ "cell_type": "markdown",
489
+ "metadata": {},
490
+ "source": [
491
+ "Посмотрим на распределение категорий в тренировочной выборке"
492
+ ]
493
+ },
494
+ {
495
+ "cell_type": "code",
496
+ "execution_count": 12,
497
+ "metadata": {},
498
+ "outputs": [
499
+ {
500
+ "name": "stdout",
501
+ "output_type": "stream",
502
+ "text": [
503
+ "{0: 27475, 3: 1591, 7: 7417, 5: 623, 2: 152, 4: 840, 6: 43, 1: 9}\n"
504
+ ]
505
+ }
506
+ ],
507
+ "source": [
508
+ "category_to_count = defaultdict(int)\n",
509
+ "for row in y_train:\n",
510
+ " for category in row:\n",
511
+ " category_to_count[category] += 1\n",
512
+ "category_to_count = dict(category_to_count)\n",
513
+ "print(category_to_count)"
514
+ ]
515
+ },
516
+ {
517
+ "cell_type": "code",
518
+ "execution_count": 13,
519
+ "metadata": {},
520
+ "outputs": [],
521
+ "source": [
522
+ "tokenizer = AlbertTokenizer.from_pretrained(USED_MODEL)\n",
523
+ "def tokenize_function(text):\n",
524
+ " return tokenizer(text, padding=\"max_length\", truncation=True)"
525
+ ]
526
+ },
527
+ {
528
+ "cell_type": "code",
529
+ "execution_count": 14,
530
+ "metadata": {},
531
+ "outputs": [
532
+ {
533
+ "name": "stdout",
534
+ "output_type": "stream",
535
+ "text": [
536
+ "<class 'transformers.tokenization_utils_base.BatchEncoding'>\n",
537
+ "['_MutableMapping__marker', '__abstractmethods__', '__class__', '__class_getitem__', '__contains__', '__copy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__ior__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__or__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__ror__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_abc_impl', '_encodings', '_n_sequences', 'char_to_token', 'char_to_word', 'clear', 'convert_to_tensors', 'copy', 'data', 'encodings', 'fromkeys', 'get', 'is_fast', 'items', 'keys', 'n_sequences', 'pop', 'popitem', 'sequence_ids', 'setdefault', 'to', 'token_to_chars', 'token_to_sequence', 'token_to_word', 'tokens', 'update', 'values', 'word_ids', 'word_to_chars', 'word_to_tokens', 'words']\n",
538
+ "3\n"
539
+ ]
540
+ }
541
+ ],
542
+ "source": [
543
+ "train_encodings = tokenize_function(X_train)\n",
544
+ "val_encodings = tokenize_function(X_val)\n",
545
+ "test_encodings = tokenize_function(X_test)\n",
546
+ "\n",
547
+ "print(type(train_encodings))\n",
548
+ "print(dir(train_encodings))\n",
549
+ "print(len(train_encodings))"
550
+ ]
551
+ },
552
+ {
553
+ "cell_type": "code",
554
+ "execution_count": 15,
555
+ "metadata": {},
556
+ "outputs": [
557
+ {
558
+ "name": "stderr",
559
+ "output_type": "stream",
560
+ "text": [
561
+ "100%|██████████| 28700/28700 [00:00<00:00, 506098.43it/s]\n",
562
+ "100%|██████████| 8200/8200 [00:00<00:00, 533767.25it/s]\n",
563
+ "100%|██████████| 4100/4100 [00:00<00:00, 516059.37it/s]\n"
564
+ ]
565
+ }
566
+ ],
567
+ "source": [
568
+ "def get_labels(y: List[List[int]]):\n",
569
+ " labels = np.zeros((len(y), len(category_to_index)))\n",
570
+ " for i in tqdm(range(len(y))):\n",
571
+ " labels[i, y[i]] = 1\n",
572
+ " return labels.tolist()\n",
573
+ "\n",
574
+ "labels_train = get_labels(y_train)\n",
575
+ "labels_val = get_labels(y_val)\n",
576
+ "labels_test = get_labels(y_test)"
577
+ ]
578
+ },
579
+ {
580
+ "cell_type": "code",
581
+ "execution_count": 16,
582
+ "metadata": {},
583
+ "outputs": [],
584
+ "source": [
585
+ "train_encodings['labels'] = labels_train\n",
586
+ "val_encodings['labels'] = labels_val\n",
587
+ "test_encodings['labels'] = labels_test"
588
+ ]
589
+ },
590
+ {
591
+ "cell_type": "markdown",
592
+ "metadata": {},
593
+ "source": [
594
+ "**Я использовал пример отсюда чтобы понимать, какой нужен формат данных https://github.com/NielsRogge/Transformers-Tutorials/blob/master/BERT/Fine_tuning_BERT_(and_friends)_for_multi_label_text_classification.ipynb**"
595
+ ]
596
+ },
597
+ {
598
+ "cell_type": "code",
599
+ "execution_count": 17,
600
+ "metadata": {},
601
+ "outputs": [],
602
+ "source": [
603
+ "train_dataset = Dataset.from_dict(train_encodings)\n",
604
+ "val_dataset = Dataset.from_dict(val_encodings)\n",
605
+ "test_dataset = Dataset.from_dict(test_encodings)\n",
606
+ "\n",
607
+ "train_dataset.set_format(\"torch\")\n",
608
+ "val_dataset.set_format(\"torch\")\n",
609
+ "test_dataset.set_format(\"torch\")"
610
+ ]
611
+ },
612
+ {
613
+ "cell_type": "code",
614
+ "execution_count": 18,
615
+ "metadata": {},
616
+ "outputs": [
617
+ {
618
+ "name": "stderr",
619
+ "output_type": "stream",
620
+ "text": [
621
+ "Some weights of AlbertForSequenceClassification were not initialized from the model checkpoint at albert-base-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
622
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
623
+ ]
624
+ }
625
+ ],
626
+ "source": [
627
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
628
+ " USED_MODEL, \n",
629
+ " problem_type=\"multi_label_classification\", \n",
630
+ " num_labels=len(category_to_index),\n",
631
+ " id2label=index_to_category,\n",
632
+ " label2id=category_to_index\n",
633
+ ")"
634
+ ]
635
+ },
636
+ {
637
+ "cell_type": "code",
638
+ "execution_count": 19,
639
+ "metadata": {},
640
+ "outputs": [],
641
+ "source": [
642
+ "batch_size = 8\n",
643
+ "metric_name = \"f1\""
644
+ ]
645
+ },
646
+ {
647
+ "cell_type": "code",
648
+ "execution_count": 20,
649
+ "metadata": {},
650
+ "outputs": [
651
+ {
652
+ "name": "stderr",
653
+ "output_type": "stream",
654
+ "text": [
655
+ "/home/jarakcyc/.virtualenvs/Tricks/lib/python3.10/site-packages/transformers/training_args.py:1575: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
656
+ " warnings.warn(\n"
657
+ ]
658
+ }
659
+ ],
660
+ "source": [
661
+ "args = TrainingArguments(\n",
662
+ " output_dir=f'train-{USED_MODEL}',\n",
663
+ " evaluation_strategy=\"epoch\",\n",
664
+ " save_strategy=\"epoch\",\n",
665
+ " learning_rate=2e-5,\n",
666
+ " per_device_train_batch_size=batch_size,\n",
667
+ " per_device_eval_batch_size=batch_size,\n",
668
+ " num_train_epochs=5,\n",
669
+ " weight_decay=0.01,\n",
670
+ " load_best_model_at_end=True,\n",
671
+ " metric_for_best_model=metric_name,\n",
672
+ " push_to_hub=False\n",
673
+ ")"
674
+ ]
675
+ },
676
+ {
677
+ "cell_type": "code",
678
+ "execution_count": 21,
679
+ "metadata": {},
680
+ "outputs": [],
681
+ "source": [
682
+ "from sklearn.metrics import f1_score, roc_auc_score, accuracy_score\n",
683
+ "from transformers import EvalPrediction\n",
684
+ "import torch\n",
685
+ " \n",
686
+ "# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/\n",
687
+ "def multi_label_metrics(predictions, labels, threshold=0.5):\n",
688
+ " sigmoid = torch.nn.Sigmoid()\n",
689
+ " probs = sigmoid(torch.Tensor(predictions))\n",
690
+ " y_pred = np.zeros(probs.shape)\n",
691
+ " y_pred[np.where(probs >= threshold)] = 1\n",
692
+ " y_true = labels\n",
693
+ " f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')\n",
694
+ " roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')\n",
695
+ " accuracy = accuracy_score(y_true, y_pred)\n",
696
+ " metrics = {'f1': f1_micro_average,\n",
697
+ " 'roc_auc': roc_auc,\n",
698
+ " 'accuracy': accuracy}\n",
699
+ " return metrics\n",
700
+ "\n",
701
+ "def compute_metrics(p: EvalPrediction):\n",
702
+ " preds = p.predictions[0] if isinstance(p.predictions, \n",
703
+ " tuple) else p.predictions\n",
704
+ " result = multi_label_metrics(\n",
705
+ " predictions=preds, \n",
706
+ " labels=p.label_ids)\n",
707
+ " return result"
708
+ ]
709
+ },
710
+ {
711
+ "cell_type": "code",
712
+ "execution_count": 22,
713
+ "metadata": {},
714
+ "outputs": [
715
+ {
716
+ "data": {
717
+ "text/plain": [
718
+ "tensor([1.3057e-01, 3.9861e+02, 2.3602e+01, 2.2549e+00, 4.2708e+00, 5.7584e+00,\n",
719
+ " 8.3430e+01, 4.8369e-01], device='cuda:0')"
720
+ ]
721
+ },
722
+ "execution_count": 22,
723
+ "metadata": {},
724
+ "output_type": "execute_result"
725
+ }
726
+ ],
727
+ "source": [
728
+ "pos_weight=torch.tensor([\n",
729
+ " len(y_train) / category_to_count[i] / len(category_to_count) for i in range(len(category_to_count))\n",
730
+ "]).to(DEVICE)\n",
731
+ "compute_loss_func_ = BCEWithLogitsLoss(pos_weight=pos_weight)\n",
732
+ "\n",
733
+ "# Example of custom trainer is taken from https://medium.com/deeplearningmadeeasy/how-to-use-a-custom-loss-with-hugging-face-fc9a1f91b39b\n",
734
+ "class CustomTrainer(Trainer):\n",
735
+ " def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):\n",
736
+ " labels = inputs.pop(\"labels\")\n",
737
+ " outputs = model(**inputs)\n",
738
+ " logits = outputs.logits\n",
739
+ " loss = compute_loss_func_(logits, labels)\n",
740
+ " return (loss, outputs) if return_outputs else loss\n",
741
+ "\n",
742
+ "pos_weight"
743
+ ]
744
+ },
745
+ {
746
+ "cell_type": "code",
747
+ "execution_count": 23,
748
+ "metadata": {},
749
+ "outputs": [
750
+ {
751
+ "name": "stderr",
752
+ "output_type": "stream",
753
+ "text": [
754
+ "/tmp/ipykernel_831875/1711637572.py:1: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `CustomTrainer.__init__`. Use `processing_class` instead.\n",
755
+ " trainer = CustomTrainer(\n"
756
+ ]
757
+ }
758
+ ],
759
+ "source": [
760
+ "trainer = CustomTrainer(\n",
761
+ " model,\n",
762
+ " args,\n",
763
+ " train_dataset=train_dataset,\n",
764
+ " eval_dataset=val_dataset,\n",
765
+ " tokenizer=tokenizer,\n",
766
+ " compute_metrics=compute_metrics\n",
767
+ ")"
768
+ ]
769
+ },
770
+ {
771
+ "cell_type": "code",
772
+ "execution_count": 24,
773
+ "metadata": {},
774
+ "outputs": [
775
+ {
776
+ "data": {
777
+ "text/html": [
778
+ "\n",
779
+ " <div>\n",
780
+ " \n",
781
+ " <progress value='17940' max='17940' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
782
+ " [17940/17940 1:04:03, Epoch 5/5]\n",
783
+ " </div>\n",
784
+ " <table border=\"1\" class=\"dataframe\">\n",
785
+ " <thead>\n",
786
+ " <tr style=\"text-align: left;\">\n",
787
+ " <th>Epoch</th>\n",
788
+ " <th>Training Loss</th>\n",
789
+ " <th>Validation Loss</th>\n",
790
+ " <th>F1</th>\n",
791
+ " <th>Roc Auc</th>\n",
792
+ " <th>Accuracy</th>\n",
793
+ " </tr>\n",
794
+ " </thead>\n",
795
+ " <tbody>\n",
796
+ " <tr>\n",
797
+ " <td>1</td>\n",
798
+ " <td>0.499100</td>\n",
799
+ " <td>0.387039</td>\n",
800
+ " <td>0.826313</td>\n",
801
+ " <td>0.866803</td>\n",
802
+ " <td>0.681951</td>\n",
803
+ " </tr>\n",
804
+ " <tr>\n",
805
+ " <td>2</td>\n",
806
+ " <td>0.295300</td>\n",
807
+ " <td>0.386058</td>\n",
808
+ " <td>0.837538</td>\n",
809
+ " <td>0.871269</td>\n",
810
+ " <td>0.694878</td>\n",
811
+ " </tr>\n",
812
+ " <tr>\n",
813
+ " <td>3</td>\n",
814
+ " <td>0.383900</td>\n",
815
+ " <td>0.354087</td>\n",
816
+ " <td>0.848541</td>\n",
817
+ " <td>0.887079</td>\n",
818
+ " <td>0.705122</td>\n",
819
+ " </tr>\n",
820
+ " <tr>\n",
821
+ " <td>4</td>\n",
822
+ " <td>0.167500</td>\n",
823
+ " <td>0.375260</td>\n",
824
+ " <td>0.850880</td>\n",
825
+ " <td>0.888822</td>\n",
826
+ " <td>0.707561</td>\n",
827
+ " </tr>\n",
828
+ " <tr>\n",
829
+ " <td>5</td>\n",
830
+ " <td>0.282100</td>\n",
831
+ " <td>0.409332</td>\n",
832
+ " <td>0.857789</td>\n",
833
+ " <td>0.898414</td>\n",
834
+ " <td>0.712927</td>\n",
835
+ " </tr>\n",
836
+ " </tbody>\n",
837
+ "</table><p>"
838
+ ],
839
+ "text/plain": [
840
+ "<IPython.core.display.HTML object>"
841
+ ]
842
+ },
843
+ "metadata": {},
844
+ "output_type": "display_data"
845
+ },
846
+ {
847
+ "data": {
848
+ "text/plain": [
849
+ "TrainOutput(global_step=17940, training_loss=0.3050233064819472, metrics={'train_runtime': 3844.3438, 'train_samples_per_second': 37.328, 'train_steps_per_second': 4.667, 'total_flos': 3431411601408000.0, 'train_loss': 0.3050233064819472, 'epoch': 5.0})"
850
+ ]
851
+ },
852
+ "execution_count": 24,
853
+ "metadata": {},
854
+ "output_type": "execute_result"
855
+ }
856
+ ],
857
+ "source": [
858
+ "trainer.train()"
859
+ ]
860
+ },
861
+ {
862
+ "cell_type": "code",
863
+ "execution_count": 25,
864
+ "metadata": {},
865
+ "outputs": [
866
+ {
867
+ "data": {
868
+ "text/html": [
869
+ "\n",
870
+ " <div>\n",
871
+ " \n",
872
+ " <progress value='1' max='1025' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
873
+ " [ 1/1025 : < :]\n",
874
+ " </div>\n",
875
+ " "
876
+ ],
877
+ "text/plain": [
878
+ "<IPython.core.display.HTML object>"
879
+ ]
880
+ },
881
+ "metadata": {},
882
+ "output_type": "display_data"
883
+ },
884
+ {
885
+ "data": {
886
+ "text/plain": [
887
+ "{'eval_loss': 0.40933170914649963,\n",
888
+ " 'eval_f1': 0.8577886788823161,\n",
889
+ " 'eval_roc_auc': 0.8984138467714379,\n",
890
+ " 'eval_accuracy': 0.7129268292682926,\n",
891
+ " 'eval_runtime': 69.0762,\n",
892
+ " 'eval_samples_per_second': 118.71,\n",
893
+ " 'eval_steps_per_second': 14.839,\n",
894
+ " 'epoch': 5.0}"
895
+ ]
896
+ },
897
+ "execution_count": 25,
898
+ "metadata": {},
899
+ "output_type": "execute_result"
900
+ }
901
+ ],
902
+ "source": [
903
+ "trainer.evaluate(eval_dataset=val_dataset)"
904
+ ]
905
+ },
906
+ {
907
+ "cell_type": "code",
908
+ "execution_count": 26,
909
+ "metadata": {},
910
+ "outputs": [
911
+ {
912
+ "data": {
913
+ "text/plain": [
914
+ "{'eval_loss': 0.5079951882362366,\n",
915
+ " 'eval_f1': 0.8536538088776895,\n",
916
+ " 'eval_roc_auc': 0.8970809313634953,\n",
917
+ " 'eval_accuracy': 0.708780487804878,\n",
918
+ " 'eval_runtime': 34.407,\n",
919
+ " 'eval_samples_per_second': 119.162,\n",
920
+ " 'eval_steps_per_second': 14.91,\n",
921
+ " 'epoch': 5.0}"
922
+ ]
923
+ },
924
+ "execution_count": 26,
925
+ "metadata": {},
926
+ "output_type": "execute_result"
927
+ }
928
+ ],
929
+ "source": [
930
+ "trainer.evaluate(eval_dataset=test_dataset)"
931
+ ]
932
+ },
933
+ {
934
+ "cell_type": "markdown",
935
+ "metadata": {},
936
+ "source": [
937
+ "Исходная задача у нас звучала как \"хотим увидеть топ-95%* тематик, отсортированных по убыванию вероятности\", где под тематиками имелись ввиду категории (физика, биология и так далее)\n",
938
+ "\n",
939
+ "Будем делать следующее:\n",
940
+ "- наша модель выдает логиты категорий\n",
941
+ "- посчитаем с их помощью вероятность категорий, считая их сумму равной 1 (хотя на самом деле категорий может быть несколько)\n",
942
+ "- выведем требуемые топ-95% тематик"
943
+ ]
944
+ },
945
+ {
946
+ "cell_type": "code",
947
+ "execution_count": 29,
948
+ "metadata": {},
949
+ "outputs": [],
950
+ "source": [
951
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
952
+ " f\"train-{USED_MODEL}/checkpoint-10764\", \n",
953
+ " problem_type=\"multi_label_classification\", \n",
954
+ " num_labels=len(category_to_index),\n",
955
+ " id2label=index_to_category,\n",
956
+ " label2id=category_to_index\n",
957
+ ").to(DEVICE)"
958
+ ]
959
+ },
960
+ {
961
+ "cell_type": "code",
962
+ "execution_count": 30,
963
+ "metadata": {},
964
+ "outputs": [
965
+ {
966
+ "data": {
967
+ "text/plain": [
968
+ "SequenceClassifierOutput(loss=None, logits=tensor([[ 3.5393, -7.6223, -5.9721, -0.6268, -3.4508, -5.2609, -5.8817, -3.5099]],\n",
969
+ " device='cuda:0', grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)"
970
+ ]
971
+ },
972
+ "execution_count": 30,
973
+ "metadata": {},
974
+ "output_type": "execute_result"
975
+ }
976
+ ],
977
+ "source": [
978
+ "model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function('Maths is cool $ In our article we prove that maths is the coolest subject at school').items()})"
979
+ ]
980
+ },
981
+ {
982
+ "cell_type": "code",
983
+ "execution_count": 31,
984
+ "metadata": {},
985
+ "outputs": [],
986
+ "source": [
987
+ "@torch.no_grad\n",
988
+ "def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]:\n",
989
+ " text = f'{title} $ {summary}'\n",
990
+ " category_logits = model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function(text).items()}).logits\n",
991
+ " sigmoid = torch.nn.Sigmoid()\n",
992
+ " category_probs = sigmoid(category_logits.squeeze().cpu()).numpy()\n",
993
+ " category_probs /= category_probs.sum()\n",
994
+ " category_probs_dict = {category: 0.0 for category in set(arxiv_topics_df['category'])}\n",
995
+ " for index in range(len(index_to_category)):\n",
996
+ " category_probs_dict[index_to_category[index]] += float(category_probs[index])\n",
997
+ " return category_probs_dict"
998
+ ]
999
+ },
1000
+ {
1001
+ "cell_type": "code",
1002
+ "execution_count": 33,
1003
+ "metadata": {},
1004
+ "outputs": [],
1005
+ "source": [
1006
+ "def get_most_probable_keys(probs_dict: Dict[str, float], target_probability: float, print_probabilities: bool) -> List[str]:\n",
1007
+ " current_p = 0\n",
1008
+ " probs_list = sorted([(value, key) for key, value in probs_dict.items()])[::-1]\n",
1009
+ " current_index = 0\n",
1010
+ " answer = []\n",
1011
+ " while current_p <= target_probability:\n",
1012
+ " current_p += probs_list[current_index][0]\n",
1013
+ " if not print_probabilities:\n",
1014
+ " answer.append(probs_list[current_index][1])\n",
1015
+ " else:\n",
1016
+ " answer.append(f'{probs_list[current_index][1]} ({probs_list[current_index][0]})')\n",
1017
+ " current_index += 1\n",
1018
+ " if current_index >= len(probs_list):\n",
1019
+ " break\n",
1020
+ " return answer"
1021
+ ]
1022
+ },
1023
+ {
1024
+ "cell_type": "markdown",
1025
+ "metadata": {},
1026
+ "source": [
1027
+ "Сохраняем модель, чтобы потом можно было её использовать в huggingface space"
1028
+ ]
1029
+ },
1030
+ {
1031
+ "cell_type": "code",
1032
+ "execution_count": 35,
1033
+ "metadata": {},
1034
+ "outputs": [
1035
+ {
1036
+ "name": "stderr",
1037
+ "output_type": "stream",
1038
+ "text": [
1039
+ "model.safetensors: 100%|██████████| 46.8M/46.8M [00:06<00:00, 7.25MB/s] \n"
1040
+ ]
1041
+ },
1042
+ {
1043
+ "data": {
1044
+ "text/plain": [
1045
+ "CommitInfo(commit_url='https://huggingface.co/bumchik2/train-albert-base-v2-tags-classification/commit/bf0db75a370cbbfeed52bcb5760c9adfea198bf9', commit_message='Upload AlbertForSequenceClassification', commit_description='', oid='bf0db75a370cbbfeed52bcb5760c9adfea198bf9', pr_url=None, repo_url=RepoUrl('https://huggingface.co/bumchik2/train-albert-base-v2-tags-classification', endpoint='https://huggingface.co', repo_type='model', repo_id='bumchik2/train-albert-base-v2-tags-classification'), pr_revision=None, pr_num=None)"
1046
+ ]
1047
+ },
1048
+ "execution_count": 35,
1049
+ "metadata": {},
1050
+ "output_type": "execute_result"
1051
+ }
1052
+ ],
1053
+ "source": [
1054
+ "model.push_to_hub(f\"bumchik2/train-{USED_MODEL}-tags-classification\")"
1055
+ ]
1056
+ },
1057
+ {
1058
+ "cell_type": "markdown",
1059
+ "metadata": {},
1060
+ "source": [
1061
+ "Теперь я могу загружать свою модель оттуда"
1062
+ ]
1063
+ },
1064
+ {
1065
+ "cell_type": "code",
1066
+ "execution_count": 36,
1067
+ "metadata": {},
1068
+ "outputs": [],
1069
+ "source": [
1070
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
1071
+ " f\"bumchik2/train-{USED_MODEL}-tags-classification\", \n",
1072
+ " problem_type=\"multi_label_classification\", \n",
1073
+ " num_labels=len(category_to_index),\n",
1074
+ " id2label=index_to_category,\n",
1075
+ " label2id=category_to_index\n",
1076
+ ").to(DEVICE)"
1077
+ ]
1078
+ },
1079
+ {
1080
+ "cell_type": "markdown",
1081
+ "metadata": {},
1082
+ "source": [
1083
+ "Протестируем на нескольких реальных примерах:"
1084
+ ]
1085
+ },
1086
+ {
1087
+ "cell_type": "code",
1088
+ "execution_count": 39,
1089
+ "metadata": {},
1090
+ "outputs": [
1091
+ {
1092
+ "data": {
1093
+ "text/plain": [
1094
+ "['Quantitative Biology (0.4282848834991455)',\n",
1095
+ " 'Statistics (0.34262675046920776)',\n",
1096
+ " 'Computer Science (0.14248277246952057)',\n",
1097
+ " 'Physics (0.034869205206632614)',\n",
1098
+ " 'Mathematics (0.029306704178452492)']"
1099
+ ]
1100
+ },
1101
+ "execution_count": 39,
1102
+ "metadata": {},
1103
+ "output_type": "execute_result"
1104
+ }
1105
+ ],
1106
+ "source": [
1107
+ "# правильный ответ Quantitative Biology\n",
1108
+ "get_most_probable_keys(\n",
1109
+ " probs_dict=get_category_probs_dict(\n",
1110
+ " model=model,\n",
1111
+ " title='Simulating cell populations with explicit cell cycle length -- implications to cell cycle dependent tumour therapy',\n",
1112
+ " summary='In this study, we present a stochastic simulation model designed to explicitly incorporate cell cycle length, overcoming limitations associated with classical compartmental models. Our approach employs a delay mechanism to represent the cell cycle, allowing the use of arbitrary distributions for cell cycle lengths. We demonstrate the feasibility of our model by fitting it to experimental data from melanoma cell lines previously studied by Vittadello et al. Notably, our model successfully replicates experimentally observed synchronization phenomena that multi-stage models could not adequately explain. By using a gamma distribution to model cell cycle lengths, we achieved excellent agreement between our simulations and empirical data, while significantly reducing computational complexity and parameter estimation challenges inherent in multi-stage approaches. Our results highlight the importance of explicitly incorporating cell cycle lengths in modeling cell populations, with potential implications for optimizing cell cycle-dependent tumor therapies.'\n",
1113
+ " ),\n",
1114
+ " target_probability=0.95,\n",
1115
+ " print_probabilities=True\n",
1116
+ ")"
1117
+ ]
1118
+ },
1119
+ {
1120
+ "cell_type": "code",
1121
+ "execution_count": 40,
1122
+ "metadata": {},
1123
+ "outputs": [
1124
+ {
1125
+ "data": {
1126
+ "text/plain": [
1127
+ "['Computer Science (0.48184847831726074)',\n",
1128
+ " 'Physics (0.41395917534828186)',\n",
1129
+ " 'Statistics (0.029943009838461876)',\n",
1130
+ " 'Electrical Engineering and Systems Science (0.028027774766087532)']"
1131
+ ]
1132
+ },
1133
+ "execution_count": 40,
1134
+ "metadata": {},
1135
+ "output_type": "execute_result"
1136
+ }
1137
+ ],
1138
+ "source": [
1139
+ "# правильный ответ Physics\n",
1140
+ "get_most_probable_keys(\n",
1141
+ " probs_dict=get_category_probs_dict(\n",
1142
+ " model=model,\n",
1143
+ " title='Performance Improvement of LTS Undulators for Synchrotron Light Sources',\n",
1144
+ " summary='The joint expertise of ANL and FNAL has led to the production of Nb3Sn undulator magnets in operation in the ANL Advanced Photon Source (APS). These magnets showed performance reproducibility close to the short sample limit, and a design field increase of 20% at 820A. However, the long training did not allow obtaining the expected 50% increase of the on-axis magnetic field with respect to the ~1 T produced at 450 A current in the ANL NbTi undulator. To address this, 10-pole long undulator prototypes were fabricated, and CTD-101K was replaced as impregnation material with TELENE, an organic olefin-based thermosetting dicyclopentadiene resin produced by RIMTEC Corporation, Japan. Training and magnet retraining after a thermal cycle were nearly eliminated, with only a couple of quenches needed before reaching short sample limit at over 1,100 A. TELENE will enable operation of Nb3Sn undulators much closer to their short sample limit, expanding the energy range and brightness intensity of light sources. TELENE is Co-60 gamma radiation resistant up to 7-8 MGy, and therefore already applicable to impregnate planar, helical and universal devices operating in lower radiation environments than high energy colliders.'\n",
1145
+ " ),\n",
1146
+ " target_probability=0.95,\n",
1147
+ " print_probabilities=True\n",
1148
+ ")"
1149
+ ]
1150
+ },
1151
+ {
1152
+ "cell_type": "code",
1153
+ "execution_count": null,
1154
+ "metadata": {},
1155
+ "outputs": [],
1156
+ "source": []
1157
+ }
1158
+ ],
1159
+ "metadata": {
1160
+ "kernelspec": {
1161
+ "display_name": "Tricks",
1162
+ "language": "python",
1163
+ "name": "python3"
1164
+ },
1165
+ "language_info": {
1166
+ "codemirror_mode": {
1167
+ "name": "ipython",
1168
+ "version": 3
1169
+ },
1170
+ "file_extension": ".py",
1171
+ "mimetype": "text/x-python",
1172
+ "name": "python",
1173
+ "nbconvert_exporter": "python",
1174
+ "pygments_lexer": "ipython3",
1175
+ "version": "3.10.12"
1176
+ }
1177
+ },
1178
+ "nbformat": 4,
1179
+ "nbformat_minor": 2
1180
+ }
notebooks/distilbert_base_cased_main.ipynb ADDED
@@ -0,0 +1,1187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/home/jarakcyc/.virtualenvs/Tricks/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n"
14
+ ]
15
+ },
16
+ {
17
+ "name": "stdout",
18
+ "output_type": "stream",
19
+ "text": [
20
+ "4.50.0\n"
21
+ ]
22
+ },
23
+ {
24
+ "data": {
25
+ "text/plain": [
26
+ "device(type='cuda')"
27
+ ]
28
+ },
29
+ "execution_count": 1,
30
+ "metadata": {},
31
+ "output_type": "execute_result"
32
+ }
33
+ ],
34
+ "source": [
35
+ "from transformers import pipeline\n",
36
+ "import json\n",
37
+ "import pandas as pd\n",
38
+ "from sklearn.model_selection import train_test_split\n",
39
+ "from transformers import DistilBertTokenizer\n",
40
+ "from tqdm import tqdm\n",
41
+ "import re\n",
42
+ "from datasets import Dataset\n",
43
+ "from transformers import AutoModelForSequenceClassification\n",
44
+ "import torch\n",
45
+ "import numpy as np\n",
46
+ "from typing import Dict\n",
47
+ "from transformers import AutoModel\n",
48
+ "from torch.nn import BCEWithLogitsLoss\n",
49
+ "from typing import List\n",
50
+ "from transformers import TrainingArguments, Trainer\n",
51
+ "from collections import defaultdict\n",
52
+ "\n",
53
+ "from transformers import __version__ as transformers_version\n",
54
+ "print(transformers_version)\n",
55
+ "\n",
56
+ "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
57
+ "DEVICE"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": 2,
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "USED_MODEL = \"distilbert-base-cased\""
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": 3,
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "def read_json(json_filename):\n",
76
+ " with open(json_filename, 'r') as f:\n",
77
+ " return json.loads(f.read())\n",
78
+ "\n",
79
+ "\n",
80
+ "def save_json(json_object, json_filename, indent=4):\n",
81
+ " with open(json_filename, 'w') as f:\n",
82
+ " json.dump(json_object, f, separators=(',', ':'), indent=indent)"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "markdown",
87
+ "metadata": {},
88
+ "source": [
89
+ "**Данные берем отсюда: https://www.kaggle.com/datasets/neelshah18/arxivdataset**"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": 4,
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "arxiv_data = read_json('arxivData.json')"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": 5,
104
+ "metadata": {},
105
+ "outputs": [
106
+ {
107
+ "data": {
108
+ "text/plain": [
109
+ "{'author': \"[{'name': 'Ahmed Osman'}, {'name': 'Wojciech Samek'}]\",\n",
110
+ " 'day': 1,\n",
111
+ " 'id': '1802.00209v1',\n",
112
+ " 'link': \"[{'rel': 'alternate', 'href': 'http://arxiv.org/abs/1802.00209v1', 'type': 'text/html'}, {'rel': 'related', 'href': 'http://arxiv.org/pdf/1802.00209v1', 'type': 'application/pdf', 'title': 'pdf'}]\",\n",
113
+ " 'month': 2,\n",
114
+ " 'summary': 'We propose an architecture for VQA which utilizes recurrent layers to\\ngenerate visual and textual attention. The memory characteristic of the\\nproposed recurrent attention units offers a rich joint embedding of visual and\\ntextual features and enables the model to reason relations between several\\nparts of the image and question. Our single model outperforms the first place\\nwinner on the VQA 1.0 dataset, performs within margin to the current\\nstate-of-the-art ensemble model. We also experiment with replacing attention\\nmechanisms in other state-of-the-art models with our implementation and show\\nincreased accuracy. In both cases, our recurrent attention mechanism improves\\nperformance in tasks requiring sequential or relational reasoning on the VQA\\ndataset.',\n",
115
+ " 'tag': \"[{'term': 'cs.AI', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'cs.CL', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'cs.CV', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'cs.NE', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'stat.ML', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}]\",\n",
116
+ " 'title': 'Dual Recurrent Attention Units for Visual Question Answering',\n",
117
+ " 'year': 2018}"
118
+ ]
119
+ },
120
+ "execution_count": 5,
121
+ "metadata": {},
122
+ "output_type": "execute_result"
123
+ }
124
+ ],
125
+ "source": [
126
+ "arxiv_data[0]"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "markdown",
131
+ "metadata": {},
132
+ "source": [
133
+ "**Хотим по названию статьи + abstract выдавать наиболее вероятную тематику статьи, скажем, физика, биология или computer science** "
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": 6,
139
+ "metadata": {},
140
+ "outputs": [
141
+ {
142
+ "name": "stdout",
143
+ "output_type": "stream",
144
+ "text": [
145
+ "155\n"
146
+ ]
147
+ },
148
+ {
149
+ "data": {
150
+ "text/html": [
151
+ "<div>\n",
152
+ "<style scoped>\n",
153
+ " .dataframe tbody tr th:only-of-type {\n",
154
+ " vertical-align: middle;\n",
155
+ " }\n",
156
+ "\n",
157
+ " .dataframe tbody tr th {\n",
158
+ " vertical-align: top;\n",
159
+ " }\n",
160
+ "\n",
161
+ " .dataframe thead th {\n",
162
+ " text-align: right;\n",
163
+ " }\n",
164
+ "</style>\n",
165
+ "<table border=\"1\" class=\"dataframe\">\n",
166
+ " <thead>\n",
167
+ " <tr style=\"text-align: right;\">\n",
168
+ " <th></th>\n",
169
+ " <th>tag</th>\n",
170
+ " <th>topic</th>\n",
171
+ " <th>category</th>\n",
172
+ " </tr>\n",
173
+ " </thead>\n",
174
+ " <tbody>\n",
175
+ " <tr>\n",
176
+ " <th>0</th>\n",
177
+ " <td>cs.AI</td>\n",
178
+ " <td>Artificial Intelligence</td>\n",
179
+ " <td>Computer Science</td>\n",
180
+ " </tr>\n",
181
+ " <tr>\n",
182
+ " <th>1</th>\n",
183
+ " <td>cs.AR</td>\n",
184
+ " <td>Hardware Architecture</td>\n",
185
+ " <td>Computer Science</td>\n",
186
+ " </tr>\n",
187
+ " <tr>\n",
188
+ " <th>2</th>\n",
189
+ " <td>cs.CC</td>\n",
190
+ " <td>Computational Complexity</td>\n",
191
+ " <td>Computer Science</td>\n",
192
+ " </tr>\n",
193
+ " <tr>\n",
194
+ " <th>3</th>\n",
195
+ " <td>cs.CE</td>\n",
196
+ " <td>Computational Engineering, Finance, and Science</td>\n",
197
+ " <td>Computer Science</td>\n",
198
+ " </tr>\n",
199
+ " <tr>\n",
200
+ " <th>4</th>\n",
201
+ " <td>cs.CG</td>\n",
202
+ " <td>Computational Geometry</td>\n",
203
+ " <td>Computer Science</td>\n",
204
+ " </tr>\n",
205
+ " </tbody>\n",
206
+ "</table>\n",
207
+ "</div>"
208
+ ],
209
+ "text/plain": [
210
+ " tag topic category\n",
211
+ "0 cs.AI Artificial Intelligence Computer Science\n",
212
+ "1 cs.AR Hardware Architecture Computer Science\n",
213
+ "2 cs.CC Computational Complexity Computer Science\n",
214
+ "3 cs.CE Computational Engineering, Finance, and Science Computer Science\n",
215
+ "4 cs.CG Computational Geometry Computer Science"
216
+ ]
217
+ },
218
+ "execution_count": 6,
219
+ "metadata": {},
220
+ "output_type": "execute_result"
221
+ }
222
+ ],
223
+ "source": [
224
+ "arxiv_topics_df = pd.read_csv('arxiv_topics.csv')\n",
225
+ "print(len(arxiv_topics_df))\n",
226
+ "arxiv_topics_df.head(5)"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": 7,
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "category_to_index = {}\n",
236
+ "tag_to_category = {}\n",
237
+ "current_index = 0\n",
238
+ "for i, row in arxiv_topics_df.iterrows():\n",
239
+ " category = row['category']\n",
240
+ " if category not in category_to_index:\n",
241
+ " category_to_index[category] = current_index\n",
242
+ " current_index += 1\n",
243
+ " tag_to_category[row['tag']] = row['category']\n",
244
+ "index_to_category = {value: key for key, value in category_to_index.items()}"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "markdown",
249
+ "metadata": {},
250
+ "source": [
251
+ "**Готовим данные к обучению**"
252
+ ]
253
+ },
254
+ {
255
+ "cell_type": "code",
256
+ "execution_count": 8,
257
+ "metadata": {},
258
+ "outputs": [
259
+ {
260
+ "name": "stderr",
261
+ "output_type": "stream",
262
+ "text": [
263
+ " 0%| | 0/41000 [00:00<?, ?it/s]"
264
+ ]
265
+ },
266
+ {
267
+ "name": "stderr",
268
+ "output_type": "stream",
269
+ "text": [
270
+ "100%|██████████| 41000/41000 [00:01<00:00, 34214.55it/s]"
271
+ ]
272
+ },
273
+ {
274
+ "name": "stdout",
275
+ "output_type": "stream",
276
+ "text": [
277
+ "Среднее число категорий в одной статье: 1.3301219512195122\n",
278
+ "Среднее число тегов в одной статье: 1.8489024390243902\n"
279
+ ]
280
+ },
281
+ {
282
+ "name": "stderr",
283
+ "output_type": "stream",
284
+ "text": [
285
+ "\n"
286
+ ]
287
+ }
288
+ ],
289
+ "source": [
290
+ "def is_valid_tag(tag: str) -> bool:\n",
291
+ " return tag in tag_to_category\n",
292
+ "\n",
293
+ "total_categories_count = 0\n",
294
+ "total_tags_count = 0\n",
295
+ "records = []\n",
296
+ "for arxiv_record in tqdm(arxiv_data):\n",
297
+ " record = {\n",
298
+ " 'title': arxiv_record['title'],\n",
299
+ " 'summary': arxiv_record['summary'],\n",
300
+ " 'title_and_summary': arxiv_record['title'] + ' $ ' + arxiv_record['summary'],\n",
301
+ " 'tags': [current_tag['term'] for current_tag in eval(arxiv_record['tag']) if is_valid_tag(current_tag['term'])]\n",
302
+ " }\n",
303
+ " categories = set(tag_to_category[tag] for tag in record['tags'])\n",
304
+ " total_categories_count += len(categories)\n",
305
+ " total_tags_count += len(record['tags'])\n",
306
+ " record['categories_indices'] = list(set([category_to_index[tag_to_category[tag]] for tag in record['tags']]))\n",
307
+ " assert len(record['tags']) > 0\n",
308
+ " records.append(record)\n",
309
+ "\n",
310
+ "print(f'Среднее число категорий в одной статье: {total_categories_count / len(arxiv_data)}')\n",
311
+ "print(f'Среднее число тегов в одной статье: {total_tags_count / len(arxiv_data)}')"
312
+ ]
313
+ },
314
+ {
315
+ "cell_type": "markdown",
316
+ "metadata": {},
317
+ "source": [
318
+ "Как видим, перед нами задача мультибинарной классификации.\n",
319
+ "\n",
320
+ "Тегов у одной статьи бывает много, это понятно, но и категорий тоже бывает много. То есть, условно статья может быть посвящена и физике и биологии одновременно.\n",
321
+ "\n",
322
+ "Попробуем обучить модель определять теги - так она потенциально может сохранить в себе больше информации, чем если ее обучить определять категории (которых гораздо меньше)."
323
+ ]
324
+ },
325
+ {
326
+ "cell_type": "markdown",
327
+ "metadata": {},
328
+ "source": [
329
+ "**Соединяем title и summary используя символ `$` - он редкий, при этом его знает токенайзер, поэтому не придется с ним дополнительно возиться**"
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "code",
334
+ "execution_count": 9,
335
+ "metadata": {},
336
+ "outputs": [
337
+ {
338
+ "name": "stdout",
339
+ "output_type": "stream",
340
+ "text": [
341
+ "41000\n"
342
+ ]
343
+ },
344
+ {
345
+ "data": {
346
+ "text/html": [
347
+ "<div>\n",
348
+ "<style scoped>\n",
349
+ " .dataframe tbody tr th:only-of-type {\n",
350
+ " vertical-align: middle;\n",
351
+ " }\n",
352
+ "\n",
353
+ " .dataframe tbody tr th {\n",
354
+ " vertical-align: top;\n",
355
+ " }\n",
356
+ "\n",
357
+ " .dataframe thead th {\n",
358
+ " text-align: right;\n",
359
+ " }\n",
360
+ "</style>\n",
361
+ "<table border=\"1\" class=\"dataframe\">\n",
362
+ " <thead>\n",
363
+ " <tr style=\"text-align: right;\">\n",
364
+ " <th></th>\n",
365
+ " <th>title</th>\n",
366
+ " <th>summary</th>\n",
367
+ " <th>title_and_summary</th>\n",
368
+ " <th>tags</th>\n",
369
+ " <th>categories_indices</th>\n",
370
+ " </tr>\n",
371
+ " </thead>\n",
372
+ " <tbody>\n",
373
+ " <tr>\n",
374
+ " <th>0</th>\n",
375
+ " <td>Dual Recurrent Attention Units for Visual Ques...</td>\n",
376
+ " <td>We propose an architecture for VQA which utili...</td>\n",
377
+ " <td>Dual Recurrent Attention Units for Visual Ques...</td>\n",
378
+ " <td>[cs.AI, cs.CL, cs.CV, cs.NE, stat.ML]</td>\n",
379
+ " <td>[0, 7]</td>\n",
380
+ " </tr>\n",
381
+ " <tr>\n",
382
+ " <th>1</th>\n",
383
+ " <td>Sequential Short-Text Classification with Recu...</td>\n",
384
+ " <td>Recent approaches based on artificial neural n...</td>\n",
385
+ " <td>Sequential Short-Text Classification with Recu...</td>\n",
386
+ " <td>[cs.CL, cs.AI, cs.LG, cs.NE, stat.ML]</td>\n",
387
+ " <td>[0, 7]</td>\n",
388
+ " </tr>\n",
389
+ " <tr>\n",
390
+ " <th>2</th>\n",
391
+ " <td>Multiresolution Recurrent Neural Networks: An ...</td>\n",
392
+ " <td>We introduce the multiresolution recurrent neu...</td>\n",
393
+ " <td>Multiresolution Recurrent Neural Networks: An ...</td>\n",
394
+ " <td>[cs.CL, cs.AI, cs.LG, cs.NE, stat.ML]</td>\n",
395
+ " <td>[0, 7]</td>\n",
396
+ " </tr>\n",
397
+ " <tr>\n",
398
+ " <th>3</th>\n",
399
+ " <td>Learning what to share between loosely related...</td>\n",
400
+ " <td>Multi-task learning is motivated by the observ...</td>\n",
401
+ " <td>Learning what to share between loosely related...</td>\n",
402
+ " <td>[stat.ML, cs.AI, cs.CL, cs.LG, cs.NE]</td>\n",
403
+ " <td>[0, 7]</td>\n",
404
+ " </tr>\n",
405
+ " <tr>\n",
406
+ " <th>4</th>\n",
407
+ " <td>A Deep Reinforcement Learning Chatbot</td>\n",
408
+ " <td>We present MILABOT: a deep reinforcement learn...</td>\n",
409
+ " <td>A Deep Reinforcement Learning Chatbot $ We pre...</td>\n",
410
+ " <td>[cs.CL, cs.AI, cs.LG, cs.NE, stat.ML]</td>\n",
411
+ " <td>[0, 7]</td>\n",
412
+ " </tr>\n",
413
+ " </tbody>\n",
414
+ "</table>\n",
415
+ "</div>"
416
+ ],
417
+ "text/plain": [
418
+ " title \\\n",
419
+ "0 Dual Recurrent Attention Units for Visual Ques... \n",
420
+ "1 Sequential Short-Text Classification with Recu... \n",
421
+ "2 Multiresolution Recurrent Neural Networks: An ... \n",
422
+ "3 Learning what to share between loosely related... \n",
423
+ "4 A Deep Reinforcement Learning Chatbot \n",
424
+ "\n",
425
+ " summary \\\n",
426
+ "0 We propose an architecture for VQA which utili... \n",
427
+ "1 Recent approaches based on artificial neural n... \n",
428
+ "2 We introduce the multiresolution recurrent neu... \n",
429
+ "3 Multi-task learning is motivated by the observ... \n",
430
+ "4 We present MILABOT: a deep reinforcement learn... \n",
431
+ "\n",
432
+ " title_and_summary \\\n",
433
+ "0 Dual Recurrent Attention Units for Visual Ques... \n",
434
+ "1 Sequential Short-Text Classification with Recu... \n",
435
+ "2 Multiresolution Recurrent Neural Networks: An ... \n",
436
+ "3 Learning what to share between loosely related... \n",
437
+ "4 A Deep Reinforcement Learning Chatbot $ We pre... \n",
438
+ "\n",
439
+ " tags categories_indices \n",
440
+ "0 [cs.AI, cs.CL, cs.CV, cs.NE, stat.ML] [0, 7] \n",
441
+ "1 [cs.CL, cs.AI, cs.LG, cs.NE, stat.ML] [0, 7] \n",
442
+ "2 [cs.CL, cs.AI, cs.LG, cs.NE, stat.ML] [0, 7] \n",
443
+ "3 [stat.ML, cs.AI, cs.CL, cs.LG, cs.NE] [0, 7] \n",
444
+ "4 [cs.CL, cs.AI, cs.LG, cs.NE, stat.ML] [0, 7] "
445
+ ]
446
+ },
447
+ "execution_count": 9,
448
+ "metadata": {},
449
+ "output_type": "execute_result"
450
+ }
451
+ ],
452
+ "source": [
453
+ "full_data_df = pd.DataFrame(records)\n",
454
+ "print(len(full_data_df))\n",
455
+ "full_data_df.head(5)"
456
+ ]
457
+ },
458
+ {
459
+ "cell_type": "markdown",
460
+ "metadata": {},
461
+ "source": [
462
+ "**Как видим, Computer science встречается очень часто. А, например, экономика - совсем редко. Значит при обучении экономике логично давать больше вес**"
463
+ ]
464
+ },
465
+ {
466
+ "cell_type": "code",
467
+ "execution_count": 11,
468
+ "metadata": {},
469
+ "outputs": [],
470
+ "source": [
471
+ "text_data = list(full_data_df['title_and_summary'])\n",
472
+ "categories_indices = list(full_data_df['categories_indices'])"
473
+ ]
474
+ },
475
+ {
476
+ "cell_type": "code",
477
+ "execution_count": 12,
478
+ "metadata": {},
479
+ "outputs": [
480
+ {
481
+ "name": "stdout",
482
+ "output_type": "stream",
483
+ "text": [
484
+ "28700 8200 4100\n"
485
+ ]
486
+ }
487
+ ],
488
+ "source": [
489
+ "X_train_val, X_test, y_train_val, y_test = train_test_split(text_data, categories_indices, test_size=0.1, random_state=42)\n",
490
+ "X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=2/9, random_state=42)\n",
491
+ "print(len(X_train), len(X_val), len(X_test))\n",
492
+ "# Train is 70%, val is 20%, test is 10%"
493
+ ]
494
+ },
495
+ {
496
+ "cell_type": "markdown",
497
+ "metadata": {},
498
+ "source": [
499
+ "Посмотрим на распределение категорий в тренировочной выборке"
500
+ ]
501
+ },
502
+ {
503
+ "cell_type": "code",
504
+ "execution_count": 40,
505
+ "metadata": {},
506
+ "outputs": [
507
+ {
508
+ "name": "stdout",
509
+ "output_type": "stream",
510
+ "text": [
511
+ "{0: 27475, 3: 1591, 7: 7417, 5: 623, 2: 152, 4: 840, 6: 43, 1: 9}\n"
512
+ ]
513
+ }
514
+ ],
515
+ "source": [
516
+ "category_to_count = defaultdict(int)\n",
517
+ "for row in y_train:\n",
518
+ " for category in row:\n",
519
+ " category_to_count[category] += 1\n",
520
+ "category_to_count = dict(category_to_count)\n",
521
+ "print(category_to_count)"
522
+ ]
523
+ },
524
+ {
525
+ "cell_type": "code",
526
+ "execution_count": 11,
527
+ "metadata": {},
528
+ "outputs": [],
529
+ "source": [
530
+ "tokenizer = DistilBertTokenizer.from_pretrained(USED_MODEL)\n",
531
+ "def tokenize_function(text):\n",
532
+ " return tokenizer(text, padding=\"max_length\", truncation=True)"
533
+ ]
534
+ },
535
+ {
536
+ "cell_type": "code",
537
+ "execution_count": 14,
538
+ "metadata": {},
539
+ "outputs": [
540
+ {
541
+ "name": "stdout",
542
+ "output_type": "stream",
543
+ "text": [
544
+ "<class 'transformers.tokenization_utils_base.BatchEncoding'>\n",
545
+ "['_MutableMapping__marker', '__abstractmethods__', '__class__', '__class_getitem__', '__contains__', '__copy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__ior__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__or__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__ror__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_abc_impl', '_encodings', '_n_sequences', 'char_to_token', 'char_to_word', 'clear', 'convert_to_tensors', 'copy', 'data', 'encodings', 'fromkeys', 'get', 'is_fast', 'items', 'keys', 'n_sequences', 'pop', 'popitem', 'sequence_ids', 'setdefault', 'to', 'token_to_chars', 'token_to_sequence', 'token_to_word', 'tokens', 'update', 'values', 'word_ids', 'word_to_chars', 'word_to_tokens', 'words']\n",
546
+ "2\n"
547
+ ]
548
+ }
549
+ ],
550
+ "source": [
551
+ "train_encodings = tokenize_function(X_train)\n",
552
+ "val_encodings = tokenize_function(X_val)\n",
553
+ "test_encodings = tokenize_function(X_test)\n",
554
+ "\n",
555
+ "print(type(train_encodings))\n",
556
+ "print(dir(train_encodings))\n",
557
+ "print(len(train_encodings))"
558
+ ]
559
+ },
560
+ {
561
+ "cell_type": "code",
562
+ "execution_count": 15,
563
+ "metadata": {},
564
+ "outputs": [
565
+ {
566
+ "name": "stderr",
567
+ "output_type": "stream",
568
+ "text": [
569
+ "100%|██████████| 28700/28700 [00:00<00:00, 521869.58it/s]\n",
570
+ "100%|██████████| 8200/8200 [00:00<00:00, 529420.80it/s]\n",
571
+ "100%|██████████| 4100/4100 [00:00<00:00, 516369.29it/s]\n"
572
+ ]
573
+ }
574
+ ],
575
+ "source": [
576
+ "def get_labels(y: List[List[int]]):\n",
577
+ " labels = np.zeros((len(y), len(category_to_index)))\n",
578
+ " for i in tqdm(range(len(y))):\n",
579
+ " labels[i, y[i]] = 1\n",
580
+ " return labels.tolist()\n",
581
+ "\n",
582
+ "labels_train = get_labels(y_train)\n",
583
+ "labels_val = get_labels(y_val)\n",
584
+ "labels_test = get_labels(y_test)"
585
+ ]
586
+ },
587
+ {
588
+ "cell_type": "code",
589
+ "execution_count": 16,
590
+ "metadata": {},
591
+ "outputs": [],
592
+ "source": [
593
+ "train_encodings['labels'] = labels_train\n",
594
+ "val_encodings['labels'] = labels_val\n",
595
+ "test_encodings['labels'] = labels_test"
596
+ ]
597
+ },
598
+ {
599
+ "cell_type": "markdown",
600
+ "metadata": {},
601
+ "source": [
602
+ "**Я использовал пример отсюда чтобы понимать, какой нужен формат данных https://github.com/NielsRogge/Transformers-Tutorials/blob/master/BERT/Fine_tuning_BERT_(and_friends)_for_multi_label_text_classification.ipynb**"
603
+ ]
604
+ },
605
+ {
606
+ "cell_type": "code",
607
+ "execution_count": 17,
608
+ "metadata": {},
609
+ "outputs": [],
610
+ "source": [
611
+ "train_dataset = Dataset.from_dict(train_encodings)\n",
612
+ "val_dataset = Dataset.from_dict(val_encodings)\n",
613
+ "test_dataset = Dataset.from_dict(test_encodings)\n",
614
+ "\n",
615
+ "train_dataset.set_format(\"torch\")\n",
616
+ "val_dataset.set_format(\"torch\")\n",
617
+ "test_dataset.set_format(\"torch\")"
618
+ ]
619
+ },
620
+ {
621
+ "cell_type": "code",
622
+ "execution_count": 18,
623
+ "metadata": {},
624
+ "outputs": [
625
+ {
626
+ "name": "stderr",
627
+ "output_type": "stream",
628
+ "text": [
629
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
630
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
631
+ ]
632
+ }
633
+ ],
634
+ "source": [
635
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
636
+ " USED_MODEL, \n",
637
+ " problem_type=\"multi_label_classification\", \n",
638
+ " num_labels=len(category_to_index),\n",
639
+ " id2label=index_to_category,\n",
640
+ " label2id=category_to_index\n",
641
+ ")"
642
+ ]
643
+ },
644
+ {
645
+ "cell_type": "code",
646
+ "execution_count": 19,
647
+ "metadata": {},
648
+ "outputs": [],
649
+ "source": [
650
+ "batch_size = 8\n",
651
+ "metric_name = \"f1\""
652
+ ]
653
+ },
654
+ {
655
+ "cell_type": "code",
656
+ "execution_count": 20,
657
+ "metadata": {},
658
+ "outputs": [
659
+ {
660
+ "name": "stderr",
661
+ "output_type": "stream",
662
+ "text": [
663
+ "/home/jarakcyc/.virtualenvs/Tricks/lib/python3.10/site-packages/transformers/training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
664
+ " warnings.warn(\n"
665
+ ]
666
+ }
667
+ ],
668
+ "source": [
669
+ "args = TrainingArguments(\n",
670
+ " output_dir=f'train-{USED_MODEL}',\n",
671
+ " evaluation_strategy=\"epoch\",\n",
672
+ " save_strategy=\"epoch\",\n",
673
+ " learning_rate=2e-5,\n",
674
+ " per_device_train_batch_size=batch_size,\n",
675
+ " per_device_eval_batch_size=batch_size,\n",
676
+ " num_train_epochs=5,\n",
677
+ " weight_decay=0.01,\n",
678
+ " load_best_model_at_end=True,\n",
679
+ " metric_for_best_model=metric_name,\n",
680
+ " push_to_hub=False\n",
681
+ ")"
682
+ ]
683
+ },
684
+ {
685
+ "cell_type": "code",
686
+ "execution_count": 21,
687
+ "metadata": {},
688
+ "outputs": [],
689
+ "source": [
690
+ "from sklearn.metrics import f1_score, roc_auc_score, accuracy_score\n",
691
+ "from transformers import EvalPrediction\n",
692
+ "import torch\n",
693
+ " \n",
694
+ "# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/\n",
695
+ "def multi_label_metrics(predictions, labels, threshold=0.5):\n",
696
+ " sigmoid = torch.nn.Sigmoid()\n",
697
+ " probs = sigmoid(torch.Tensor(predictions))\n",
698
+ " y_pred = np.zeros(probs.shape)\n",
699
+ " y_pred[np.where(probs >= threshold)] = 1\n",
700
+ " y_true = labels\n",
701
+ " f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')\n",
702
+ " roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')\n",
703
+ " accuracy = accuracy_score(y_true, y_pred)\n",
704
+ " metrics = {'f1': f1_micro_average,\n",
705
+ " 'roc_auc': roc_auc,\n",
706
+ " 'accuracy': accuracy}\n",
707
+ " return metrics\n",
708
+ "\n",
709
+ "def compute_metrics(p: EvalPrediction):\n",
710
+ " preds = p.predictions[0] if isinstance(p.predictions, \n",
711
+ " tuple) else p.predictions\n",
712
+ " result = multi_label_metrics(\n",
713
+ " predictions=preds, \n",
714
+ " labels=p.label_ids)\n",
715
+ " return result"
716
+ ]
717
+ },
718
+ {
719
+ "cell_type": "code",
720
+ "execution_count": 41,
721
+ "metadata": {},
722
+ "outputs": [
723
+ {
724
+ "data": {
725
+ "text/plain": [
726
+ "tensor([1.3057e-01, 3.9861e+02, 2.3602e+01, 2.2549e+00, 4.2708e+00, 5.7584e+00,\n",
727
+ " 8.3430e+01, 4.8369e-01], device='cuda:0')"
728
+ ]
729
+ },
730
+ "execution_count": 41,
731
+ "metadata": {},
732
+ "output_type": "execute_result"
733
+ }
734
+ ],
735
+ "source": [
736
+ "pos_weight=torch.tensor([\n",
737
+ " len(y_train) / category_to_count[i] / len(category_to_count) for i in range(len(category_to_count))\n",
738
+ "]).to(DEVICE)\n",
739
+ "compute_loss_func_ = BCEWithLogitsLoss(pos_weight=pos_weight)\n",
740
+ "\n",
741
+ "# Example of custom trainer is taken from https://medium.com/deeplearningmadeeasy/how-to-use-a-custom-loss-with-hugging-face-fc9a1f91b39b\n",
742
+ "class CustomTrainer(Trainer):\n",
743
+ " def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):\n",
744
+ " labels = inputs.pop(\"labels\")\n",
745
+ " outputs = model(**inputs)\n",
746
+ " logits = outputs.logits\n",
747
+ " loss = compute_loss_func_(logits, labels)\n",
748
+ " return (loss, outputs) if return_outputs else loss\n",
749
+ "\n",
750
+ "pos_weight"
751
+ ]
752
+ },
753
+ {
754
+ "cell_type": "code",
755
+ "execution_count": 42,
756
+ "metadata": {},
757
+ "outputs": [
758
+ {
759
+ "name": "stderr",
760
+ "output_type": "stream",
761
+ "text": [
762
+ "/tmp/ipykernel_725923/1711637572.py:1: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `CustomTrainer.__init__`. Use `processing_class` instead.\n",
763
+ " trainer = CustomTrainer(\n"
764
+ ]
765
+ }
766
+ ],
767
+ "source": [
768
+ "trainer = CustomTrainer(\n",
769
+ " model,\n",
770
+ " args,\n",
771
+ " train_dataset=train_dataset,\n",
772
+ " eval_dataset=val_dataset,\n",
773
+ " tokenizer=tokenizer,\n",
774
+ " compute_metrics=compute_metrics\n",
775
+ ")"
776
+ ]
777
+ },
778
+ {
779
+ "cell_type": "code",
780
+ "execution_count": 43,
781
+ "metadata": {},
782
+ "outputs": [
783
+ {
784
+ "data": {
785
+ "text/html": [
786
+ "\n",
787
+ " <div>\n",
788
+ " \n",
789
+ " <progress value='17940' max='17940' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
790
+ " [17940/17940 30:28, Epoch 5/5]\n",
791
+ " </div>\n",
792
+ " <table border=\"1\" class=\"dataframe\">\n",
793
+ " <thead>\n",
794
+ " <tr style=\"text-align: left;\">\n",
795
+ " <th>Epoch</th>\n",
796
+ " <th>Training Loss</th>\n",
797
+ " <th>Validation Loss</th>\n",
798
+ " <th>F1</th>\n",
799
+ " <th>Roc Auc</th>\n",
800
+ " <th>Accuracy</th>\n",
801
+ " </tr>\n",
802
+ " </thead>\n",
803
+ " <tbody>\n",
804
+ " <tr>\n",
805
+ " <td>1</td>\n",
806
+ " <td>0.499800</td>\n",
807
+ " <td>0.367392</td>\n",
808
+ " <td>0.820691</td>\n",
809
+ " <td>0.860402</td>\n",
810
+ " <td>0.678049</td>\n",
811
+ " </tr>\n",
812
+ " <tr>\n",
813
+ " <td>2</td>\n",
814
+ " <td>0.260400</td>\n",
815
+ " <td>0.375715</td>\n",
816
+ " <td>0.834150</td>\n",
817
+ " <td>0.874018</td>\n",
818
+ " <td>0.687805</td>\n",
819
+ " </tr>\n",
820
+ " <tr>\n",
821
+ " <td>3</td>\n",
822
+ " <td>0.278600</td>\n",
823
+ " <td>0.355979</td>\n",
824
+ " <td>0.830082</td>\n",
825
+ " <td>0.877995</td>\n",
826
+ " <td>0.679756</td>\n",
827
+ " </tr>\n",
828
+ " <tr>\n",
829
+ " <td>4</td>\n",
830
+ " <td>0.106900</td>\n",
831
+ " <td>0.435449</td>\n",
832
+ " <td>0.851014</td>\n",
833
+ " <td>0.889915</td>\n",
834
+ " <td>0.707439</td>\n",
835
+ " </tr>\n",
836
+ " <tr>\n",
837
+ " <td>5</td>\n",
838
+ " <td>0.241700</td>\n",
839
+ " <td>0.463859</td>\n",
840
+ " <td>0.853753</td>\n",
841
+ " <td>0.896515</td>\n",
842
+ " <td>0.705976</td>\n",
843
+ " </tr>\n",
844
+ " </tbody>\n",
845
+ "</table><p>"
846
+ ],
847
+ "text/plain": [
848
+ "<IPython.core.display.HTML object>"
849
+ ]
850
+ },
851
+ "metadata": {},
852
+ "output_type": "display_data"
853
+ },
854
+ {
855
+ "data": {
856
+ "text/plain": [
857
+ "TrainOutput(global_step=17940, training_loss=0.2534386833641707, metrics={'train_runtime': 1828.955, 'train_samples_per_second': 78.46, 'train_steps_per_second': 9.809, 'total_flos': 1.9011105705984e+16, 'train_loss': 0.2534386833641707, 'epoch': 5.0})"
858
+ ]
859
+ },
860
+ "execution_count": 43,
861
+ "metadata": {},
862
+ "output_type": "execute_result"
863
+ }
864
+ ],
865
+ "source": [
866
+ "trainer.train()"
867
+ ]
868
+ },
869
+ {
870
+ "cell_type": "code",
871
+ "execution_count": 44,
872
+ "metadata": {},
873
+ "outputs": [
874
+ {
875
+ "data": {
876
+ "text/html": [
877
+ "\n",
878
+ " <div>\n",
879
+ " \n",
880
+ " <progress value='2' max='1025' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
881
+ " [ 2/1025 00:00 < 00:29, 34.60 it/s]\n",
882
+ " </div>\n",
883
+ " "
884
+ ],
885
+ "text/plain": [
886
+ "<IPython.core.display.HTML object>"
887
+ ]
888
+ },
889
+ "metadata": {},
890
+ "output_type": "display_data"
891
+ },
892
+ {
893
+ "data": {
894
+ "text/plain": [
895
+ "{'eval_loss': 0.46385854482650757,\n",
896
+ " 'eval_f1': 0.8537534220258393,\n",
897
+ " 'eval_roc_auc': 0.8965147172006488,\n",
898
+ " 'eval_accuracy': 0.7059756097560975,\n",
899
+ " 'eval_runtime': 30.1947,\n",
900
+ " 'eval_samples_per_second': 271.571,\n",
901
+ " 'eval_steps_per_second': 33.946,\n",
902
+ " 'epoch': 5.0}"
903
+ ]
904
+ },
905
+ "execution_count": 44,
906
+ "metadata": {},
907
+ "output_type": "execute_result"
908
+ }
909
+ ],
910
+ "source": [
911
+ "trainer.evaluate(eval_dataset=val_dataset)"
912
+ ]
913
+ },
914
+ {
915
+ "cell_type": "code",
916
+ "execution_count": 45,
917
+ "metadata": {},
918
+ "outputs": [
919
+ {
920
+ "data": {
921
+ "text/plain": [
922
+ "{'eval_loss': 0.5708531737327576,\n",
923
+ " 'eval_f1': 0.8577275384915222,\n",
924
+ " 'eval_roc_auc': 0.8977722975061319,\n",
925
+ " 'eval_accuracy': 0.718780487804878,\n",
926
+ " 'eval_runtime': 16.2221,\n",
927
+ " 'eval_samples_per_second': 252.742,\n",
928
+ " 'eval_steps_per_second': 31.624,\n",
929
+ " 'epoch': 5.0}"
930
+ ]
931
+ },
932
+ "execution_count": 45,
933
+ "metadata": {},
934
+ "output_type": "execute_result"
935
+ }
936
+ ],
937
+ "source": [
938
+ "trainer.evaluate(eval_dataset=test_dataset)"
939
+ ]
940
+ },
941
+ {
942
+ "cell_type": "markdown",
943
+ "metadata": {},
944
+ "source": [
945
+ "Исходная задача у нас звучала как \"хотим увидеть топ-95%* тематик, отсортированных по убыванию вероятности\", где под тематиками имелись ввиду категории (физика, биология и так далее)\n",
946
+ "\n",
947
+ "Будем делать следующее:\n",
948
+ "- наша модель выдает логиты категорий\n",
949
+ "- посчитаем с их помощью вероятность категорий, считая их сумму равной 1 (хотя на самом деле категорий может быть несколько)\n",
950
+ "- выведем требуемые топ-95% тематик"
951
+ ]
952
+ },
953
+ {
954
+ "cell_type": "code",
955
+ "execution_count": null,
956
+ "metadata": {},
957
+ "outputs": [],
958
+ "source": [
959
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
960
+ " f\"train-{USED_MODEL}/checkpoint-10764\", \n",
961
+ " problem_type=\"multi_label_classification\", \n",
962
+ " num_labels=len(category_to_index),\n",
963
+ " id2label=index_to_category,\n",
964
+ " label2id=category_to_index\n",
965
+ ").to(DEVICE)"
966
+ ]
967
+ },
968
+ {
969
+ "cell_type": "code",
970
+ "execution_count": 12,
971
+ "metadata": {},
972
+ "outputs": [
973
+ {
974
+ "data": {
975
+ "text/plain": [
976
+ "SequenceClassifierOutput(loss=None, logits=tensor([[ 1.9626, -6.2954, -3.8152, -1.4332, -3.2061, -4.2741, -5.5496, -1.9070]],\n",
977
+ " device='cuda:0', grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)"
978
+ ]
979
+ },
980
+ "execution_count": 12,
981
+ "metadata": {},
982
+ "output_type": "execute_result"
983
+ }
984
+ ],
985
+ "source": [
986
+ "model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function('Maths is cool $ In our article we prove that maths is the coolest subject at school').items()})"
987
+ ]
988
+ },
989
+ {
990
+ "cell_type": "code",
991
+ "execution_count": 13,
992
+ "metadata": {},
993
+ "outputs": [],
994
+ "source": [
995
+ "@torch.no_grad\n",
996
+ "def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]:\n",
997
+ " text = f'{title} $ {summary}'\n",
998
+ " category_logits = model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function(text).items()}).logits\n",
999
+ " sigmoid = torch.nn.Sigmoid()\n",
1000
+ " category_probs = sigmoid(category_logits.squeeze().cpu()).numpy()\n",
1001
+ " category_probs /= category_probs.sum()\n",
1002
+ " category_probs_dict = {category: 0.0 for category in set(arxiv_topics_df['category'])}\n",
1003
+ " for index in range(len(index_to_category)):\n",
1004
+ " category_probs_dict[index_to_category[index]] += float(category_probs[index])\n",
1005
+ " return category_probs_dict"
1006
+ ]
1007
+ },
1008
+ {
1009
+ "cell_type": "code",
1010
+ "execution_count": 16,
1011
+ "metadata": {},
1012
+ "outputs": [],
1013
+ "source": [
1014
+ "def get_most_probable_keys(probs_dict: Dict[str, float], target_probability: float, print_probabilities: bool) -> List[str]:\n",
1015
+ " current_p = 0\n",
1016
+ " probs_list = sorted([(value, key) for key, value in probs_dict.items()])[::-1]\n",
1017
+ " current_index = 0\n",
1018
+ " answer = []\n",
1019
+ " while current_p <= target_probability:\n",
1020
+ " current_p += probs_list[current_index][0]\n",
1021
+ " if not print_probabilities:\n",
1022
+ " answer.append(probs_list[current_index][1])\n",
1023
+ " else:\n",
1024
+ " answer.append(f'{probs_list[current_index][1]} ({probs_list[current_index][0]})')\n",
1025
+ " current_index += 1\n",
1026
+ " if current_index >= len(probs_list):\n",
1027
+ " break\n",
1028
+ " return answer"
1029
+ ]
1030
+ },
1031
+ {
1032
+ "cell_type": "markdown",
1033
+ "metadata": {},
1034
+ "source": [
1035
+ "Сохраняем модель, чтобы потом можно было её использовать в huggingface space"
1036
+ ]
1037
+ },
1038
+ {
1039
+ "cell_type": "code",
1040
+ "execution_count": null,
1041
+ "metadata": {},
1042
+ "outputs": [
1043
+ {
1044
+ "name": "stderr",
1045
+ "output_type": "stream",
1046
+ "text": [
1047
+ "model.safetensors: 100%|██████████| 263M/263M [00:26<00:00, 9.85MB/s] \n"
1048
+ ]
1049
+ },
1050
+ {
1051
+ "data": {
1052
+ "text/plain": [
1053
+ "CommitInfo(commit_url='https://huggingface.co/bumchik2/train-distilbert-base-cased-tags-classification/commit/01aedbd739a5cbd09aad052d81bf874ec2b07f22', commit_message='Upload DistilBertForSequenceClassification', commit_description='', oid='01aedbd739a5cbd09aad052d81bf874ec2b07f22', pr_url=None, repo_url=RepoUrl('https://huggingface.co/bumchik2/train-distilbert-base-cased-tags-classification', endpoint='https://huggingface.co', repo_type='model', repo_id='bumchik2/train-distilbert-base-cased-tags-classification'), pr_revision=None, pr_num=None)"
1054
+ ]
1055
+ },
1056
+ "execution_count": 19,
1057
+ "metadata": {},
1058
+ "output_type": "execute_result"
1059
+ }
1060
+ ],
1061
+ "source": [
1062
+ "model.push_to_hub(f\"bumchik2/train-{USED_MODEL}-tags-classification\")"
1063
+ ]
1064
+ },
1065
+ {
1066
+ "cell_type": "markdown",
1067
+ "metadata": {},
1068
+ "source": [
1069
+ "Теперь я могу загружать свою модель оттуда"
1070
+ ]
1071
+ },
1072
+ {
1073
+ "cell_type": "code",
1074
+ "execution_count": null,
1075
+ "metadata": {},
1076
+ "outputs": [],
1077
+ "source": [
1078
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
1079
+ " f\"bumchik2/train-{USED_MODEL}-tags-classification\", \n",
1080
+ " problem_type=\"multi_label_classification\", \n",
1081
+ " num_labels=len(category_to_index),\n",
1082
+ " id2label=index_to_category,\n",
1083
+ " label2id=category_to_index\n",
1084
+ ").to(DEVICE)"
1085
+ ]
1086
+ },
1087
+ {
1088
+ "cell_type": "markdown",
1089
+ "metadata": {},
1090
+ "source": [
1091
+ "Протестируем на нескольких реальных примерах:"
1092
+ ]
1093
+ },
1094
+ {
1095
+ "cell_type": "code",
1096
+ "execution_count": null,
1097
+ "metadata": {},
1098
+ "outputs": [
1099
+ {
1100
+ "data": {
1101
+ "text/plain": [
1102
+ "['Quantitative Biology (0.42653992772102356)',\n",
1103
+ " 'Statistics (0.2740147113800049)',\n",
1104
+ " 'Computer Science (0.22953785955905914)',\n",
1105
+ " 'Mathematics (0.032985616475343704)']"
1106
+ ]
1107
+ },
1108
+ "execution_count": 24,
1109
+ "metadata": {},
1110
+ "output_type": "execute_result"
1111
+ }
1112
+ ],
1113
+ "source": [
1114
+ "# правильный ответ Quantitative Biology\n",
1115
+ "get_most_probable_keys(\n",
1116
+ " probs_dict=get_category_probs_dict(\n",
1117
+ " model=model,\n",
1118
+ " title='Simulating cell populations with explicit cell cycle length -- implications to cell cycle dependent tumour therapy',\n",
1119
+ " summary='In this study, we present a stochastic simulation model designed to explicitly incorporate cell cycle length, overcoming limitations associated with classical compartmental models. Our approach employs a delay mechanism to represent the cell cycle, allowing the use of arbitrary distributions for cell cycle lengths. We demonstrate the feasibility of our model by fitting it to experimental data from melanoma cell lines previously studied by Vittadello et al. Notably, our model successfully replicates experimentally observed synchronization phenomena that multi-stage models could not adequately explain. By using a gamma distribution to model cell cycle lengths, we achieved excellent agreement between our simulations and empirical data, while significantly reducing computational complexity and parameter estimation challenges inherent in multi-stage approaches. Our results highlight the importance of explicitly incorporating cell cycle lengths in modeling cell populations, with potential implications for optimizing cell cycle-dependent tumor therapies.'\n",
1120
+ " ),\n",
1121
+ " target_probability=0.95,\n",
1122
+ " print_probabilities=True\n",
1123
+ ")"
1124
+ ]
1125
+ },
1126
+ {
1127
+ "cell_type": "code",
1128
+ "execution_count": null,
1129
+ "metadata": {},
1130
+ "outputs": [
1131
+ {
1132
+ "data": {
1133
+ "text/plain": [
1134
+ "['Physics (0.4614427089691162)',\n",
1135
+ " 'Computer Science (0.4365394413471222)',\n",
1136
+ " 'Statistics (0.03875430300831795)',\n",
1137
+ " 'Mathematics (0.024689726531505585)']"
1138
+ ]
1139
+ },
1140
+ "execution_count": 25,
1141
+ "metadata": {},
1142
+ "output_type": "execute_result"
1143
+ }
1144
+ ],
1145
+ "source": [
1146
+ "# правильный ответ Physics\n",
1147
+ "get_most_probable_keys(\n",
1148
+ " probs_dict=get_category_probs_dict(\n",
1149
+ " model=model,\n",
1150
+ " title='Performance Improvement of LTS Undulators for Synchrotron Light Sources',\n",
1151
+ " summary='The joint expertise of ANL and FNAL has led to the production of Nb3Sn undulator magnets in operation in the ANL Advanced Photon Source (APS). These magnets showed performance reproducibility close to the short sample limit, and a design field increase of 20% at 820A. However, the long training did not allow obtaining the expected 50% increase of the on-axis magnetic field with respect to the ~1 T produced at 450 A current in the ANL NbTi undulator. To address this, 10-pole long undulator prototypes were fabricated, and CTD-101K was replaced as impregnation material with TELENE, an organic olefin-based thermosetting dicyclopentadiene resin produced by RIMTEC Corporation, Japan. Training and magnet retraining after a thermal cycle were nearly eliminated, with only a couple of quenches needed before reaching short sample limit at over 1,100 A. TELENE will enable operation of Nb3Sn undulators much closer to their short sample limit, expanding the energy range and brightness intensity of light sources. TELENE is Co-60 gamma radiation resistant up to 7-8 MGy, and therefore already applicable to impregnate planar, helical and universal devices operating in lower radiation environments than high energy colliders.'\n",
1152
+ " ),\n",
1153
+ " target_probability=0.95,\n",
1154
+ " print_probabilities=True\n",
1155
+ ")"
1156
+ ]
1157
+ },
1158
+ {
1159
+ "cell_type": "code",
1160
+ "execution_count": null,
1161
+ "metadata": {},
1162
+ "outputs": [],
1163
+ "source": []
1164
+ }
1165
+ ],
1166
+ "metadata": {
1167
+ "kernelspec": {
1168
+ "display_name": "Tricks",
1169
+ "language": "python",
1170
+ "name": "python3"
1171
+ },
1172
+ "language_info": {
1173
+ "codemirror_mode": {
1174
+ "name": "ipython",
1175
+ "version": 3
1176
+ },
1177
+ "file_extension": ".py",
1178
+ "mimetype": "text/x-python",
1179
+ "name": "python",
1180
+ "nbconvert_exporter": "python",
1181
+ "pygments_lexer": "ipython3",
1182
+ "version": "3.10.12"
1183
+ }
1184
+ },
1185
+ "nbformat": 4,
1186
+ "nbformat_minor": 2
1187
+ }
notebooks/distilbert_base_cased_main_baseline.ipynb ADDED
@@ -0,0 +1,1127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/home/jarakcyc/.virtualenvs/Tricks/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n"
14
+ ]
15
+ }
16
+ ],
17
+ "source": [
18
+ "from transformers import pipeline\n",
19
+ "import json\n",
20
+ "import pandas as pd\n",
21
+ "from sklearn.model_selection import train_test_split\n",
22
+ "from transformers import DistilBertTokenizer\n",
23
+ "from tqdm import tqdm\n",
24
+ "import re\n",
25
+ "from datasets import Dataset\n",
26
+ "from transformers import AutoModelForSequenceClassification\n",
27
+ "import torch\n",
28
+ "import numpy as np\n",
29
+ "from typing import Dict\n",
30
+ "from transformers import AutoModel\n",
31
+ "from typing import List\n",
32
+ "from transformers import TrainingArguments, Trainer\n",
33
+ "from collections import defaultdict"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": 11,
39
+ "metadata": {},
40
+ "outputs": [],
41
+ "source": [
42
+ "USED_MODEL = \"distilbert-base-cased\""
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": 9,
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "def read_json(json_filename):\n",
52
+ " with open(json_filename, 'r') as f:\n",
53
+ " return json.loads(f.read())\n",
54
+ "\n",
55
+ "\n",
56
+ "def save_json(json_object, json_filename, indent=4):\n",
57
+ " with open(json_filename, 'w') as f:\n",
58
+ " json.dump(json_object, f, separators=(',', ':'), indent=indent)"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "markdown",
63
+ "metadata": {},
64
+ "source": [
65
+ "**Данные берем отсюда: https://www.kaggle.com/datasets/neelshah18/arxivdataset**"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "execution_count": 10,
71
+ "metadata": {},
72
+ "outputs": [],
73
+ "source": [
74
+ "arxiv_data = read_json('arxivData.json')"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": 11,
80
+ "metadata": {},
81
+ "outputs": [
82
+ {
83
+ "data": {
84
+ "text/plain": [
85
+ "{'author': \"[{'name': 'Ahmed Osman'}, {'name': 'Wojciech Samek'}]\",\n",
86
+ " 'day': 1,\n",
87
+ " 'id': '1802.00209v1',\n",
88
+ " 'link': \"[{'rel': 'alternate', 'href': 'http://arxiv.org/abs/1802.00209v1', 'type': 'text/html'}, {'rel': 'related', 'href': 'http://arxiv.org/pdf/1802.00209v1', 'type': 'application/pdf', 'title': 'pdf'}]\",\n",
89
+ " 'month': 2,\n",
90
+ " 'summary': 'We propose an architecture for VQA which utilizes recurrent layers to\\ngenerate visual and textual attention. The memory characteristic of the\\nproposed recurrent attention units offers a rich joint embedding of visual and\\ntextual features and enables the model to reason relations between several\\nparts of the image and question. Our single model outperforms the first place\\nwinner on the VQA 1.0 dataset, performs within margin to the current\\nstate-of-the-art ensemble model. We also experiment with replacing attention\\nmechanisms in other state-of-the-art models with our implementation and show\\nincreased accuracy. In both cases, our recurrent attention mechanism improves\\nperformance in tasks requiring sequential or relational reasoning on the VQA\\ndataset.',\n",
91
+ " 'tag': \"[{'term': 'cs.AI', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'cs.CL', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'cs.CV', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'cs.NE', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'stat.ML', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}]\",\n",
92
+ " 'title': 'Dual Recurrent Attention Units for Visual Question Answering',\n",
93
+ " 'year': 2018}"
94
+ ]
95
+ },
96
+ "execution_count": 11,
97
+ "metadata": {},
98
+ "output_type": "execute_result"
99
+ }
100
+ ],
101
+ "source": [
102
+ "arxiv_data[0]"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "markdown",
107
+ "metadata": {},
108
+ "source": [
109
+ "**Хотим по названию статьи + abstract выдавать наиболее вероятную тематику статьи, скажем, физика, биология или computer science** "
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": 2,
115
+ "metadata": {},
116
+ "outputs": [
117
+ {
118
+ "name": "stdout",
119
+ "output_type": "stream",
120
+ "text": [
121
+ "155\n"
122
+ ]
123
+ },
124
+ {
125
+ "data": {
126
+ "text/html": [
127
+ "<div>\n",
128
+ "<style scoped>\n",
129
+ " .dataframe tbody tr th:only-of-type {\n",
130
+ " vertical-align: middle;\n",
131
+ " }\n",
132
+ "\n",
133
+ " .dataframe tbody tr th {\n",
134
+ " vertical-align: top;\n",
135
+ " }\n",
136
+ "\n",
137
+ " .dataframe thead th {\n",
138
+ " text-align: right;\n",
139
+ " }\n",
140
+ "</style>\n",
141
+ "<table border=\"1\" class=\"dataframe\">\n",
142
+ " <thead>\n",
143
+ " <tr style=\"text-align: right;\">\n",
144
+ " <th></th>\n",
145
+ " <th>tag</th>\n",
146
+ " <th>topic</th>\n",
147
+ " <th>category</th>\n",
148
+ " </tr>\n",
149
+ " </thead>\n",
150
+ " <tbody>\n",
151
+ " <tr>\n",
152
+ " <th>0</th>\n",
153
+ " <td>cs.AI</td>\n",
154
+ " <td>Artificial Intelligence</td>\n",
155
+ " <td>Computer Science</td>\n",
156
+ " </tr>\n",
157
+ " <tr>\n",
158
+ " <th>1</th>\n",
159
+ " <td>cs.AR</td>\n",
160
+ " <td>Hardware Architecture</td>\n",
161
+ " <td>Computer Science</td>\n",
162
+ " </tr>\n",
163
+ " <tr>\n",
164
+ " <th>2</th>\n",
165
+ " <td>cs.CC</td>\n",
166
+ " <td>Computational Complexity</td>\n",
167
+ " <td>Computer Science</td>\n",
168
+ " </tr>\n",
169
+ " <tr>\n",
170
+ " <th>3</th>\n",
171
+ " <td>cs.CE</td>\n",
172
+ " <td>Computational Engineering, Finance, and Science</td>\n",
173
+ " <td>Computer Science</td>\n",
174
+ " </tr>\n",
175
+ " <tr>\n",
176
+ " <th>4</th>\n",
177
+ " <td>cs.CG</td>\n",
178
+ " <td>Computational Geometry</td>\n",
179
+ " <td>Computer Science</td>\n",
180
+ " </tr>\n",
181
+ " </tbody>\n",
182
+ "</table>\n",
183
+ "</div>"
184
+ ],
185
+ "text/plain": [
186
+ " tag topic category\n",
187
+ "0 cs.AI Artificial Intelligence Computer Science\n",
188
+ "1 cs.AR Hardware Architecture Computer Science\n",
189
+ "2 cs.CC Computational Complexity Computer Science\n",
190
+ "3 cs.CE Computational Engineering, Finance, and Science Computer Science\n",
191
+ "4 cs.CG Computational Geometry Computer Science"
192
+ ]
193
+ },
194
+ "execution_count": 2,
195
+ "metadata": {},
196
+ "output_type": "execute_result"
197
+ }
198
+ ],
199
+ "source": [
200
+ "# Manually prepared dataframe with arxiv topics\n",
201
+ "arxiv_topics_df = pd.read_csv('arxiv_topics.csv')\n",
202
+ "print(len(arxiv_topics_df))\n",
203
+ "arxiv_topics_df.head(5)"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "execution_count": 3,
209
+ "metadata": {},
210
+ "outputs": [],
211
+ "source": [
212
+ "tag_to_index = {}\n",
213
+ "tag_to_category = {}\n",
214
+ "for i, row in arxiv_topics_df.iterrows():\n",
215
+ " tag_to_index[row['tag']] = i\n",
216
+ " tag_to_category[row['tag']] = row['category']\n",
217
+ "index_to_tag = {value: key for key, value in tag_to_index.items()}"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "markdown",
222
+ "metadata": {},
223
+ "source": [
224
+ "**Готовим данные к обучению**"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": 49,
230
+ "metadata": {},
231
+ "outputs": [
232
+ {
233
+ "name": "stderr",
234
+ "output_type": "stream",
235
+ "text": [
236
+ "100%|██████████| 41000/41000 [00:01<00:00, 33941.59it/s]"
237
+ ]
238
+ },
239
+ {
240
+ "name": "stdout",
241
+ "output_type": "stream",
242
+ "text": [
243
+ "Среднее число категорий в одной статье: 1.3301219512195122\n",
244
+ "Среднее число тегов в одной статье: 1.8489024390243902\n"
245
+ ]
246
+ },
247
+ {
248
+ "name": "stderr",
249
+ "output_type": "stream",
250
+ "text": [
251
+ "\n"
252
+ ]
253
+ }
254
+ ],
255
+ "source": [
256
+ "def is_valid_tag(tag: str) -> bool:\n",
257
+ " return tag in tag_to_index\n",
258
+ "\n",
259
+ "total_categories_count = 0\n",
260
+ "total_tags_count = 0\n",
261
+ "records = []\n",
262
+ "for arxiv_record in tqdm(arxiv_data):\n",
263
+ " record = {\n",
264
+ " 'title': arxiv_record['title'],\n",
265
+ " 'summary': arxiv_record['summary'],\n",
266
+ " 'title_and_summary': arxiv_record['title'] + ' $ ' + arxiv_record['summary'],\n",
267
+ " 'tags': sorted([current_tag['term'] for current_tag in eval(arxiv_record['tag']) if is_valid_tag(current_tag['term'])], key=lambda x: tag_to_index[x])\n",
268
+ " }\n",
269
+ " categories = set(tag_to_category[tag] for tag in record['tags'])\n",
270
+ " total_categories_count += len(categories)\n",
271
+ " total_tags_count += len(record['tags'])\n",
272
+ " record['tags_indices'] = [tag_to_index[tag] for tag in record['tags']]\n",
273
+ " assert len(record['tags']) > 0\n",
274
+ " records.append(record)\n",
275
+ "\n",
276
+ "print(f'Среднее число категорий в одной статье: {total_categories_count / len(arxiv_data)}')\n",
277
+ "print(f'Среднее число тегов в одной статье: {total_tags_count / len(arxiv_data)}')"
278
+ ]
279
+ },
280
+ {
281
+ "cell_type": "markdown",
282
+ "metadata": {},
283
+ "source": [
284
+ "Как видим, перед нами задача мультибинарной классификации.\n",
285
+ "\n",
286
+ "Тегов у одной статьи бывает много, это понятно, но и категорий тоже бывает много. То есть, условно статья может быть посвящена и физике и биологии одновременно.\n",
287
+ "\n",
288
+ "Попробуем обучить модель определять теги - так она потенциально может сохранить в себе больше информации, чем если ее обучить определять категории (которых гораздо меньше)."
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "markdown",
293
+ "metadata": {},
294
+ "source": [
295
+ "**Соединяем title и summary используя символ `$` - он редкий, при этом его знает токенайзер, поэтому не придется с ним дополнительно возиться**"
296
+ ]
297
+ },
298
+ {
299
+ "cell_type": "code",
300
+ "execution_count": 50,
301
+ "metadata": {},
302
+ "outputs": [
303
+ {
304
+ "name": "stdout",
305
+ "output_type": "stream",
306
+ "text": [
307
+ "41000\n"
308
+ ]
309
+ },
310
+ {
311
+ "data": {
312
+ "text/html": [
313
+ "<div>\n",
314
+ "<style scoped>\n",
315
+ " .dataframe tbody tr th:only-of-type {\n",
316
+ " vertical-align: middle;\n",
317
+ " }\n",
318
+ "\n",
319
+ " .dataframe tbody tr th {\n",
320
+ " vertical-align: top;\n",
321
+ " }\n",
322
+ "\n",
323
+ " .dataframe thead th {\n",
324
+ " text-align: right;\n",
325
+ " }\n",
326
+ "</style>\n",
327
+ "<table border=\"1\" class=\"dataframe\">\n",
328
+ " <thead>\n",
329
+ " <tr style=\"text-align: right;\">\n",
330
+ " <th></th>\n",
331
+ " <th>title</th>\n",
332
+ " <th>summary</th>\n",
333
+ " <th>title_and_summary</th>\n",
334
+ " <th>tags</th>\n",
335
+ " <th>tags_indices</th>\n",
336
+ " </tr>\n",
337
+ " </thead>\n",
338
+ " <tbody>\n",
339
+ " <tr>\n",
340
+ " <th>0</th>\n",
341
+ " <td>Dual Recurrent Attention Units for Visual Ques...</td>\n",
342
+ " <td>We propose an architecture for VQA which utili...</td>\n",
343
+ " <td>Dual Recurrent Attention Units for Visual Ques...</td>\n",
344
+ " <td>[cs.AI, cs.CL, cs.CV, cs.NE, stat.ML]</td>\n",
345
+ " <td>[0, 5, 7, 28, 152]</td>\n",
346
+ " </tr>\n",
347
+ " <tr>\n",
348
+ " <th>1</th>\n",
349
+ " <td>Sequential Short-Text Classification with Recu...</td>\n",
350
+ " <td>Recent approaches based on artificial neural n...</td>\n",
351
+ " <td>Sequential Short-Text Classification with Recu...</td>\n",
352
+ " <td>[cs.AI, cs.CL, cs.LG, cs.NE, stat.ML]</td>\n",
353
+ " <td>[0, 5, 22, 28, 152]</td>\n",
354
+ " </tr>\n",
355
+ " <tr>\n",
356
+ " <th>2</th>\n",
357
+ " <td>Multiresolution Recurrent Neural Networks: An ...</td>\n",
358
+ " <td>We introduce the multiresolution recurrent neu...</td>\n",
359
+ " <td>Multiresolution Recurrent Neural Networks: An ...</td>\n",
360
+ " <td>[cs.AI, cs.CL, cs.LG, cs.NE, stat.ML]</td>\n",
361
+ " <td>[0, 5, 22, 28, 152]</td>\n",
362
+ " </tr>\n",
363
+ " <tr>\n",
364
+ " <th>3</th>\n",
365
+ " <td>Learning what to share between loosely related...</td>\n",
366
+ " <td>Multi-task learning is motivated by the observ...</td>\n",
367
+ " <td>Learning what to share between loosely related...</td>\n",
368
+ " <td>[cs.AI, cs.CL, cs.LG, cs.NE, stat.ML]</td>\n",
369
+ " <td>[0, 5, 22, 28, 152]</td>\n",
370
+ " </tr>\n",
371
+ " <tr>\n",
372
+ " <th>4</th>\n",
373
+ " <td>A Deep Reinforcement Learning Chatbot</td>\n",
374
+ " <td>We present MILABOT: a deep reinforcement learn...</td>\n",
375
+ " <td>A Deep Reinforcement Learning Chatbot $ We pre...</td>\n",
376
+ " <td>[cs.AI, cs.CL, cs.LG, cs.NE, stat.ML]</td>\n",
377
+ " <td>[0, 5, 22, 28, 152]</td>\n",
378
+ " </tr>\n",
379
+ " </tbody>\n",
380
+ "</table>\n",
381
+ "</div>"
382
+ ],
383
+ "text/plain": [
384
+ " title \\\n",
385
+ "0 Dual Recurrent Attention Units for Visual Ques... \n",
386
+ "1 Sequential Short-Text Classification with Recu... \n",
387
+ "2 Multiresolution Recurrent Neural Networks: An ... \n",
388
+ "3 Learning what to share between loosely related... \n",
389
+ "4 A Deep Reinforcement Learning Chatbot \n",
390
+ "\n",
391
+ " summary \\\n",
392
+ "0 We propose an architecture for VQA which utili... \n",
393
+ "1 Recent approaches based on artificial neural n... \n",
394
+ "2 We introduce the multiresolution recurrent neu... \n",
395
+ "3 Multi-task learning is motivated by the observ... \n",
396
+ "4 We present MILABOT: a deep reinforcement learn... \n",
397
+ "\n",
398
+ " title_and_summary \\\n",
399
+ "0 Dual Recurrent Attention Units for Visual Ques... \n",
400
+ "1 Sequential Short-Text Classification with Recu... \n",
401
+ "2 Multiresolution Recurrent Neural Networks: An ... \n",
402
+ "3 Learning what to share between loosely related... \n",
403
+ "4 A Deep Reinforcement Learning Chatbot $ We pre... \n",
404
+ "\n",
405
+ " tags tags_indices \n",
406
+ "0 [cs.AI, cs.CL, cs.CV, cs.NE, stat.ML] [0, 5, 7, 28, 152] \n",
407
+ "1 [cs.AI, cs.CL, cs.LG, cs.NE, stat.ML] [0, 5, 22, 28, 152] \n",
408
+ "2 [cs.AI, cs.CL, cs.LG, cs.NE, stat.ML] [0, 5, 22, 28, 152] \n",
409
+ "3 [cs.AI, cs.CL, cs.LG, cs.NE, stat.ML] [0, 5, 22, 28, 152] \n",
410
+ "4 [cs.AI, cs.CL, cs.LG, cs.NE, stat.ML] [0, 5, 22, 28, 152] "
411
+ ]
412
+ },
413
+ "execution_count": 50,
414
+ "metadata": {},
415
+ "output_type": "execute_result"
416
+ }
417
+ ],
418
+ "source": [
419
+ "full_data_df = pd.DataFrame(records)\n",
420
+ "print(len(full_data_df))\n",
421
+ "full_data_df.head(5)"
422
+ ]
423
+ },
424
+ {
425
+ "cell_type": "markdown",
426
+ "metadata": {},
427
+ "source": [
428
+ "Посмотрим на распределение тегов и категорий в данных"
429
+ ]
430
+ },
431
+ {
432
+ "cell_type": "code",
433
+ "execution_count": 57,
434
+ "metadata": {},
435
+ "outputs": [
436
+ {
437
+ "name": "stdout",
438
+ "output_type": "stream",
439
+ "text": [
440
+ "defaultdict(<class 'int'>, {'Statistics': 10618, 'Computer Science': 39251, 'Physics': 1208, 'Mathematics': 2263, 'Quantitative Biology': 896, 'Electrical Engineering and Systems Science': 220, 'Quantitative Finance': 66, 'Economics': 13})\n",
441
+ "defaultdict(<class 'int'>, {'cs.AI': 10481, 'cs.CL': 6417, 'cs.CV': 13902, 'cs.NE': 3819, 'stat.ML': 10326, 'cs.LG': 13735, 'physics.soc-ph': 293, 'stat.AP': 360, 'cs.RO': 973, 'cs.SE': 180, 'cs.MA': 268, 'math.OC': 1020, 'cs.IR': 1443, 'cond-mat.dis-nn': 126, 'stat.ME': 458, 'physics.chem-ph': 16, 'cs.DC': 404, 'stat.CO': 260, 'q-bio.NC': 513, 'cs.GT': 318, 'cs.MM': 345, 'cs.CG': 94, 'cs.CR': 411, 'cs.HC': 434, 'cs.GL': 10, 'eess.AS': 89, 'cs.SD': 389, 'math.DS': 49, 'cs.GR': 225, 'math.NA': 172, 'cs.CY': 376, 'physics.data-an': 187, 'math.ST': 336, 'stat.TH': 336, 'cs.IT': 543, 'math.IT': 543, 'quant-ph': 142, 'astro-ph.GA': 6, 'astro-ph.IM': 76, 'cs.SI': 639, 'cs.DB': 327, 'cs.LO': 643, 'nlin.AO': 119, 'cs.PF': 35, 'cs.ET': 85, 'eess.IV': 85, 'cs.AR': 52, 'cs.SY': 270, 'cs.CC': 196, 'q-bio.BM': 30, 'q-bio.QM': 232, 'cs.NI': 137, 'cs.DS': 570, 'cond-mat.stat-mech': 84, 'cs.NA': 253, 'cs.DM': 101, 'eess.SP': 52, 'cs.MS': 66, 'physics.med-ph': 81, 'physics.optics': 60, 'q-fin.CP': 14, 'cs.FL': 50, 'cs.SC': 24, 'q-fin.EC': 5, 'q-fin.TR': 9, 'cond-mat.mes-hall': 14, 'math.PR': 144, 'q-fin.RM': 3, 'nlin.CD': 29, 'cs.CE': 285, 'math.AT': 13, 'stat.OT': 8, 'physics.ao-ph': 19, 'math.SP': 7, 'cs.PL': 128, 'math.AP': 13, 'math.FA': 43, 'gr-qc': 6, 'physics.geo-ph': 14, 'q-bio.TO': 8, 'physics.comp-ph': 34, 'cs.DL': 139, 'math.CO': 33, 'physics.flu-dyn': 3, 'math.MG': 9, 'astro-ph.EP': 4, 'q-bio.CB': 5, 'hep-th': 6, 'math.RA': 11, 'astro-ph.CO': 10, 'cond-mat.mtrl-sci': 12, 'q-fin.ST': 15, 'q-bio.GN': 50, 'hep-ex': 9, 'nlin.CG': 18, 'nlin.PS': 3, 'math.HO': 8, 'q-fin.GN': 13, 'math.LO': 37, 'math.CT': 26, 'q-bio.PE': 84, 'astro-ph.SR': 9, 'q-fin.PM': 12, 'physics.bio-ph': 34, 'math.AG': 21, 'cs.OH': 11, 'math.DG': 17, 'astro-ph.HE': 4, 'econ.EM': 13, 'math.QA': 2, 'q-bio.SC': 3, 'math.GM': 3, 'q-bio.MN': 26, 'math.GT': 5, 'math.AC': 3, 'math.CA': 6, 'cond-mat.str-el': 5, 'math.GN': 4, 'hep-ph': 6, 'cond-mat.supr-con': 4, 'q-bio.OT': 5, 'nucl-th': 2, 'physics.ins-det': 9, 'hep-lat': 3, 'physics.app-ph': 1, 'math.RT': 3, 'math.MP': 4, 'math-ph': 4, 'physics.class-ph': 2, 'q-fin.PR': 1, 'physics.space-ph': 2, 'physics.gen-ph': 1, 'cond-mat.other': 2, 'math.GR': 4, 'nucl-ex': 3, 'cond-mat.quant-gas': 1, 'math.OA': 2, 'physics.hist-ph': 4, 'math.NT': 1, 'cs.OS': 2, 'cond-mat.soft': 2, 'physics.pop-ph': 1, 'math.CV': 1})\n"
442
+ ]
443
+ }
444
+ ],
445
+ "source": [
446
+ "tag_to_count = defaultdict(int)\n",
447
+ "category_to_count = defaultdict(int)\n",
448
+ "for i, row in full_data_df.iterrows():\n",
449
+ " found_categories = set()\n",
450
+ " for tag in row['tags']:\n",
451
+ " tag_to_count[tag] += 1\n",
452
+ " found_categories.add(tag_to_category[tag])\n",
453
+ " for category in found_categories:\n",
454
+ " category_to_count[category] += 1\n",
455
+ "print(category_to_count)\n",
456
+ "print(tag_to_count)"
457
+ ]
458
+ },
459
+ {
460
+ "cell_type": "markdown",
461
+ "metadata": {},
462
+ "source": [
463
+ "**Как видим, Computer science встречается очень часто. А, например, экономика - совсем редко**\n",
464
+ "\n",
465
+ "**Это по-хорошему нужно учесть, но в рамках данного ноутбука мы это делать не будем**"
466
+ ]
467
+ },
468
+ {
469
+ "cell_type": "code",
470
+ "execution_count": 10,
471
+ "metadata": {},
472
+ "outputs": [],
473
+ "source": [
474
+ "text_data = list(full_data_df['title_and_summary'])\n",
475
+ "tags_indices = list(full_data_df['tags_indices'])"
476
+ ]
477
+ },
478
+ {
479
+ "cell_type": "code",
480
+ "execution_count": 11,
481
+ "metadata": {},
482
+ "outputs": [
483
+ {
484
+ "name": "stdout",
485
+ "output_type": "stream",
486
+ "text": [
487
+ "28700 8200 4100\n"
488
+ ]
489
+ }
490
+ ],
491
+ "source": [
492
+ "X_train_val, X_test, y_train_val, y_test = train_test_split(text_data, tags_indices, test_size=0.1, random_state=42)\n",
493
+ "X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=2/9, random_state=42)\n",
494
+ "print(len(X_train), len(X_val), len(X_test))\n",
495
+ "# Train is 70%, val is 20%, test is 10%"
496
+ ]
497
+ },
498
+ {
499
+ "cell_type": "code",
500
+ "execution_count": 12,
501
+ "metadata": {},
502
+ "outputs": [],
503
+ "source": [
504
+ "tokenizer = DistilBertTokenizer.from_pretrained(USED_MODEL)"
505
+ ]
506
+ },
507
+ {
508
+ "cell_type": "code",
509
+ "execution_count": 13,
510
+ "metadata": {},
511
+ "outputs": [],
512
+ "source": [
513
+ "def tokenize_function(text):\n",
514
+ " return tokenizer(text, padding=\"max_length\", truncation=True)"
515
+ ]
516
+ },
517
+ {
518
+ "cell_type": "code",
519
+ "execution_count": 17,
520
+ "metadata": {},
521
+ "outputs": [
522
+ {
523
+ "name": "stdout",
524
+ "output_type": "stream",
525
+ "text": [
526
+ "Dual Recurrent Attention Units for Visual Question Answering $ We propose an architecture for VQA which utilizes recurrent layers to\n",
527
+ "generate visual and textual attention. The memory characteristic of the\n",
528
+ "proposed recurrent attention units offers a rich joint embedding of visual and\n",
529
+ "textual features and enables the model to reason relations between several\n",
530
+ "parts of the image and question. Our single model outperforms the first place\n",
531
+ "winner on the VQA 1.0 dataset, performs within margin to the current\n",
532
+ "state-of-the-art ensemble model. We also experiment with replacing attention\n",
533
+ "mechanisms in other state-of-the-art models with our implementation and show\n",
534
+ "increased accuracy. In both cases, our recurrent attention mechanism improves\n",
535
+ "performance in tasks requiring sequential or relational reasoning on the VQA\n",
536
+ "dataset.\n"
537
+ ]
538
+ },
539
+ {
540
+ "data": {
541
+ "text/plain": [
542
+ "{'input_ids': [101, 27791, 11336, 21754, 1335, 5208, 2116, 21687, 1111, 12071, 22171, 26018, 1158, 109, 1284, 17794, 1126, 4220, 1111, 159, 4880, 1592, 1134, 24242, 1231, 21754, 8798, 1106, 9509, 5173, 1105, 3087, 4746, 2209, 119, 1109, 2962, 7987, 1104, 1103, 3000, 1231, 21754, 2209, 2338, 3272, 170, 3987, 4091, 9712, 4774, 3408, 1104, 5173, 1105, 3087, 4746, 1956, 1105, 13267, 1103, 2235, 1106, 2255, 4125, 1206, 1317, 2192, 1104, 1103, 3077, 1105, 2304, 119, 3458, 1423, 2235, 1149, 3365, 13199, 1116, 1103, 1148, 1282, 2981, 1113, 1103, 159, 4880, 1592, 122, 119, 121, 2233, 9388, 117, 10383, 1439, 7464, 1106, 1103, 1954, 1352, 118, 1104, 118, 1103, 118, 1893, 9525, 2235, 119, 1284, 1145, 7886, 1114, 5861, 2209, 10748, 1107, 1168, 1352, 118, 1104, 118, 1103, 118, 1893, 3584, 1114, 1412, 7249, 1105, 1437, 2569, 10893, 119, 1130, 1241, 2740, 117, 1412, 1231, 21754, 2209, 6978, 4607, 1116, 2099, 1107, 8249, 8753, 14516, 21967, 1137, 6796, 1348, 14417, 1113, 1103, 159, 4880, 1592, 2233, 9388, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}"
543
+ ]
544
+ },
545
+ "execution_count": 17,
546
+ "metadata": {},
547
+ "output_type": "execute_result"
548
+ }
549
+ ],
550
+ "source": [
551
+ "print(text_data[0])\n",
552
+ "tokenize_function(text_data[0])"
553
+ ]
554
+ },
555
+ {
556
+ "cell_type": "code",
557
+ "execution_count": 18,
558
+ "metadata": {},
559
+ "outputs": [],
560
+ "source": [
561
+ "train_encodings = tokenize_function(X_train)\n",
562
+ "val_encodings = tokenize_function(X_val)\n",
563
+ "test_encodings = tokenize_function(X_test)"
564
+ ]
565
+ },
566
+ {
567
+ "cell_type": "code",
568
+ "execution_count": 19,
569
+ "metadata": {},
570
+ "outputs": [
571
+ {
572
+ "name": "stdout",
573
+ "output_type": "stream",
574
+ "text": [
575
+ "<class 'transformers.tokenization_utils_base.BatchEncoding'>\n",
576
+ "['_MutableMapping__marker', '__abstractmethods__', '__class__', '__class_getitem__', '__contains__', '__copy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__ior__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__or__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__ror__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_abc_impl', '_encodings', '_n_sequences', 'char_to_token', 'char_to_word', 'clear', 'convert_to_tensors', 'copy', 'data', 'encodings', 'fromkeys', 'get', 'is_fast', 'items', 'keys', 'n_sequences', 'pop', 'popitem', 'sequence_ids', 'setdefault', 'to', 'token_to_chars', 'token_to_sequence', 'token_to_word', 'tokens', 'update', 'values', 'word_ids', 'word_to_chars', 'word_to_tokens', 'words']\n",
577
+ "2\n"
578
+ ]
579
+ }
580
+ ],
581
+ "source": [
582
+ "print(type(train_encodings))\n",
583
+ "print(dir(train_encodings))\n",
584
+ "print(len(train_encodings))"
585
+ ]
586
+ },
587
+ {
588
+ "cell_type": "code",
589
+ "execution_count": 20,
590
+ "metadata": {},
591
+ "outputs": [],
592
+ "source": [
593
+ "def get_labels(y: List[List[int]]):\n",
594
+ " labels = np.zeros((len(y), len(tag_to_index)))\n",
595
+ " for i in tqdm(range(len(y))):\n",
596
+ " labels[i, y[i]] = 1\n",
597
+ " return labels.tolist()"
598
+ ]
599
+ },
600
+ {
601
+ "cell_type": "code",
602
+ "execution_count": 21,
603
+ "metadata": {},
604
+ "outputs": [
605
+ {
606
+ "name": "stderr",
607
+ "output_type": "stream",
608
+ "text": [
609
+ "100%|██████████| 28700/28700 [00:00<00:00, 388780.42it/s]\n",
610
+ "100%|██████████| 8200/8200 [00:00<00:00, 223262.03it/s]\n",
611
+ "100%|██████████| 4100/4100 [00:00<00:00, 165215.75it/s]\n"
612
+ ]
613
+ }
614
+ ],
615
+ "source": [
616
+ "labels_train = get_labels(y_train)\n",
617
+ "labels_val = get_labels(y_val)\n",
618
+ "labels_test = get_labels(y_test)"
619
+ ]
620
+ },
621
+ {
622
+ "cell_type": "code",
623
+ "execution_count": 22,
624
+ "metadata": {},
625
+ "outputs": [],
626
+ "source": [
627
+ "train_encodings['labels'] = labels_train\n",
628
+ "val_encodings['labels'] = labels_val\n",
629
+ "test_encodings['labels'] = labels_test"
630
+ ]
631
+ },
632
+ {
633
+ "cell_type": "markdown",
634
+ "metadata": {},
635
+ "source": [
636
+ "**Я использовал пример отсюда чтобы понимать, какой нужен формат данных https://github.com/NielsRogge/Transformers-Tutorials/blob/master/BERT/Fine_tuning_BERT_(and_friends)_for_multi_label_text_classification.ipynb**"
637
+ ]
638
+ },
639
+ {
640
+ "cell_type": "code",
641
+ "execution_count": 23,
642
+ "metadata": {},
643
+ "outputs": [],
644
+ "source": [
645
+ "train_dataset = Dataset.from_dict(train_encodings)\n",
646
+ "val_dataset = Dataset.from_dict(val_encodings)\n",
647
+ "test_dataset = Dataset.from_dict(test_encodings)"
648
+ ]
649
+ },
650
+ {
651
+ "cell_type": "code",
652
+ "execution_count": 24,
653
+ "metadata": {},
654
+ "outputs": [
655
+ {
656
+ "name": "stderr",
657
+ "output_type": "stream",
658
+ "text": [
659
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
660
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
661
+ ]
662
+ }
663
+ ],
664
+ "source": [
665
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
666
+ " USED_MODEL, \n",
667
+ " problem_type=\"multi_label_classification\", \n",
668
+ " num_labels=len(tag_to_index),\n",
669
+ " id2label=index_to_tag,\n",
670
+ " label2id=tag_to_index\n",
671
+ ")"
672
+ ]
673
+ },
674
+ {
675
+ "cell_type": "code",
676
+ "execution_count": 25,
677
+ "metadata": {},
678
+ "outputs": [],
679
+ "source": [
680
+ "batch_size = 8\n",
681
+ "metric_name = \"f1\""
682
+ ]
683
+ },
684
+ {
685
+ "cell_type": "code",
686
+ "execution_count": null,
687
+ "metadata": {},
688
+ "outputs": [
689
+ {
690
+ "name": "stderr",
691
+ "output_type": "stream",
692
+ "text": [
693
+ "/home/jarakcyc/.virtualenvs/Tricks/lib/python3.10/site-packages/transformers/training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
694
+ " warnings.warn(\n"
695
+ ]
696
+ }
697
+ ],
698
+ "source": [
699
+ "args = TrainingArguments(\n",
700
+ " output_dir=f'train-{USED_MODEL}-baseline',\n",
701
+ " evaluation_strategy=\"epoch\",\n",
702
+ " save_strategy=\"epoch\",\n",
703
+ " learning_rate=2e-5,\n",
704
+ " per_device_train_batch_size=batch_size,\n",
705
+ " per_device_eval_batch_size=batch_size,\n",
706
+ " num_train_epochs=5,\n",
707
+ " weight_decay=0.01,\n",
708
+ " load_best_model_at_end=True,\n",
709
+ " metric_for_best_model=metric_name,\n",
710
+ " push_to_hub=False\n",
711
+ ")"
712
+ ]
713
+ },
714
+ {
715
+ "cell_type": "code",
716
+ "execution_count": 27,
717
+ "metadata": {},
718
+ "outputs": [],
719
+ "source": [
720
+ "from sklearn.metrics import f1_score, roc_auc_score, accuracy_score\n",
721
+ "from transformers import EvalPrediction\n",
722
+ "import torch\n",
723
+ " \n",
724
+ "# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/\n",
725
+ "def multi_label_metrics(predictions, labels, threshold=0.5):\n",
726
+ " # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)\n",
727
+ " sigmoid = torch.nn.Sigmoid()\n",
728
+ " probs = sigmoid(torch.Tensor(predictions))\n",
729
+ " # next, use threshold to turn them into integer predictions\n",
730
+ " y_pred = np.zeros(probs.shape)\n",
731
+ " y_pred[np.where(probs >= threshold)] = 1\n",
732
+ " # finally, compute metrics\n",
733
+ " y_true = labels\n",
734
+ " f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')\n",
735
+ " roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')\n",
736
+ " accuracy = accuracy_score(y_true, y_pred)\n",
737
+ " # return as dictionary\n",
738
+ " metrics = {'f1': f1_micro_average,\n",
739
+ " 'roc_auc': roc_auc,\n",
740
+ " 'accuracy': accuracy}\n",
741
+ " return metrics\n",
742
+ "\n",
743
+ "def compute_metrics(p: EvalPrediction):\n",
744
+ " preds = p.predictions[0] if isinstance(p.predictions, \n",
745
+ " tuple) else p.predictions\n",
746
+ " result = multi_label_metrics(\n",
747
+ " predictions=preds, \n",
748
+ " labels=p.label_ids)\n",
749
+ " return result"
750
+ ]
751
+ },
752
+ {
753
+ "cell_type": "code",
754
+ "execution_count": 28,
755
+ "metadata": {},
756
+ "outputs": [],
757
+ "source": [
758
+ "train_dataset.set_format(\"torch\")\n",
759
+ "test_dataset.set_format(\"torch\")"
760
+ ]
761
+ },
762
+ {
763
+ "cell_type": "code",
764
+ "execution_count": 29,
765
+ "metadata": {},
766
+ "outputs": [
767
+ {
768
+ "name": "stderr",
769
+ "output_type": "stream",
770
+ "text": [
771
+ "/tmp/ipykernel_571129/1751307119.py:1: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n",
772
+ " trainer = Trainer(\n"
773
+ ]
774
+ }
775
+ ],
776
+ "source": [
777
+ "trainer = Trainer(\n",
778
+ " model,\n",
779
+ " args,\n",
780
+ " train_dataset=train_dataset,\n",
781
+ " eval_dataset=val_dataset,\n",
782
+ " tokenizer=tokenizer,\n",
783
+ " compute_metrics=compute_metrics\n",
784
+ ")"
785
+ ]
786
+ },
787
+ {
788
+ "cell_type": "code",
789
+ "execution_count": 30,
790
+ "metadata": {},
791
+ "outputs": [
792
+ {
793
+ "data": {
794
+ "text/html": [
795
+ "\n",
796
+ " <div>\n",
797
+ " \n",
798
+ " <progress value='17940' max='17940' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
799
+ " [17940/17940 32:06, Epoch 5/5]\n",
800
+ " </div>\n",
801
+ " <table border=\"1\" class=\"dataframe\">\n",
802
+ " <thead>\n",
803
+ " <tr style=\"text-align: left;\">\n",
804
+ " <th>Epoch</th>\n",
805
+ " <th>Training Loss</th>\n",
806
+ " <th>Validation Loss</th>\n",
807
+ " <th>F1</th>\n",
808
+ " <th>Roc Auc</th>\n",
809
+ " <th>Accuracy</th>\n",
810
+ " </tr>\n",
811
+ " </thead>\n",
812
+ " <tbody>\n",
813
+ " <tr>\n",
814
+ " <td>1</td>\n",
815
+ " <td>0.024000</td>\n",
816
+ " <td>0.022899</td>\n",
817
+ " <td>0.652954</td>\n",
818
+ " <td>0.770167</td>\n",
819
+ " <td>0.410366</td>\n",
820
+ " </tr>\n",
821
+ " <tr>\n",
822
+ " <td>2</td>\n",
823
+ " <td>0.020400</td>\n",
824
+ " <td>0.020730</td>\n",
825
+ " <td>0.673765</td>\n",
826
+ " <td>0.785226</td>\n",
827
+ " <td>0.426829</td>\n",
828
+ " </tr>\n",
829
+ " <tr>\n",
830
+ " <td>3</td>\n",
831
+ " <td>0.017900</td>\n",
832
+ " <td>0.019692</td>\n",
833
+ " <td>0.700292</td>\n",
834
+ " <td>0.812313</td>\n",
835
+ " <td>0.425000</td>\n",
836
+ " </tr>\n",
837
+ " <tr>\n",
838
+ " <td>4</td>\n",
839
+ " <td>0.016100</td>\n",
840
+ " <td>0.019695</td>\n",
841
+ " <td>0.701593</td>\n",
842
+ " <td>0.812366</td>\n",
843
+ " <td>0.433171</td>\n",
844
+ " </tr>\n",
845
+ " <tr>\n",
846
+ " <td>5</td>\n",
847
+ " <td>0.014800</td>\n",
848
+ " <td>0.019767</td>\n",
849
+ " <td>0.701193</td>\n",
850
+ " <td>0.812710</td>\n",
851
+ " <td>0.431707</td>\n",
852
+ " </tr>\n",
853
+ " </tbody>\n",
854
+ "</table><p>"
855
+ ],
856
+ "text/plain": [
857
+ "<IPython.core.display.HTML object>"
858
+ ]
859
+ },
860
+ "metadata": {},
861
+ "output_type": "display_data"
862
+ },
863
+ {
864
+ "data": {
865
+ "text/plain": [
866
+ "TrainOutput(global_step=17940, training_loss=0.02238395190159214, metrics={'train_runtime': 1927.2238, 'train_samples_per_second': 74.459, 'train_steps_per_second': 9.309, 'total_flos': 1.906093867776e+16, 'train_loss': 0.02238395190159214, 'epoch': 5.0})"
867
+ ]
868
+ },
869
+ "execution_count": 30,
870
+ "metadata": {},
871
+ "output_type": "execute_result"
872
+ }
873
+ ],
874
+ "source": [
875
+ "trainer.train()"
876
+ ]
877
+ },
878
+ {
879
+ "cell_type": "code",
880
+ "execution_count": 32,
881
+ "metadata": {},
882
+ "outputs": [
883
+ {
884
+ "data": {
885
+ "text/plain": [
886
+ "{'eval_loss': 0.019695421680808067,\n",
887
+ " 'eval_f1': 0.7015928686248721,\n",
888
+ " 'eval_roc_auc': 0.8123655228058703,\n",
889
+ " 'eval_accuracy': 0.43317073170731707,\n",
890
+ " 'eval_runtime': 34.8656,\n",
891
+ " 'eval_samples_per_second': 235.189,\n",
892
+ " 'eval_steps_per_second': 29.399,\n",
893
+ " 'epoch': 5.0}"
894
+ ]
895
+ },
896
+ "execution_count": 32,
897
+ "metadata": {},
898
+ "output_type": "execute_result"
899
+ }
900
+ ],
901
+ "source": [
902
+ "trainer.evaluate(eval_dataset=val_dataset)"
903
+ ]
904
+ },
905
+ {
906
+ "cell_type": "code",
907
+ "execution_count": 33,
908
+ "metadata": {},
909
+ "outputs": [
910
+ {
911
+ "data": {
912
+ "text/plain": [
913
+ "{'eval_loss': 0.019682902842760086,\n",
914
+ " 'eval_f1': 0.6966158423205653,\n",
915
+ " 'eval_roc_auc': 0.8081637343174538,\n",
916
+ " 'eval_accuracy': 0.4370731707317073,\n",
917
+ " 'eval_runtime': 16.5771,\n",
918
+ " 'eval_samples_per_second': 247.329,\n",
919
+ " 'eval_steps_per_second': 30.946,\n",
920
+ " 'epoch': 5.0}"
921
+ ]
922
+ },
923
+ "execution_count": 33,
924
+ "metadata": {},
925
+ "output_type": "execute_result"
926
+ }
927
+ ],
928
+ "source": [
929
+ "trainer.evaluate(eval_dataset=test_dataset)"
930
+ ]
931
+ },
932
+ {
933
+ "cell_type": "markdown",
934
+ "metadata": {},
935
+ "source": [
936
+ "Исходная задача у нас звучала как \"хотим увидеть топ-95%* тематик, отсортированных по убыванию вероятности\", где под тематиками имелись ввиду категории (физика, биология и так далее)\n",
937
+ "\n",
938
+ "Будем делать следующее:\n",
939
+ "- наша модель выдает логиты тегов\n",
940
+ "- посчитаем с их помощью вероятность каждого тега, считая сумму вероятностей равной 1\n",
941
+ "- посчитаем вероятность категории как сумму вероятностей тегов\n",
942
+ "- выведем требуемые топ-95% тематик"
943
+ ]
944
+ },
945
+ {
946
+ "cell_type": "code",
947
+ "execution_count": 5,
948
+ "metadata": {},
949
+ "outputs": [],
950
+ "source": [
951
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
952
+ " \"train_distilbert-base-cased/checkpoint-17940\", \n",
953
+ " problem_type=\"multi_label_classification\", \n",
954
+ " num_labels=len(tag_to_index),\n",
955
+ " id2label=index_to_tag,\n",
956
+ " label2id=tag_to_index\n",
957
+ ").to(torch.device('cuda'))"
958
+ ]
959
+ },
960
+ {
961
+ "cell_type": "code",
962
+ "execution_count": 10,
963
+ "metadata": {},
964
+ "outputs": [
965
+ {
966
+ "data": {
967
+ "text/plain": [
968
+ "SequenceClassifierOutput(loss=None, logits=tensor([[-1.3623, -5.3834, -3.3988, -3.4555, -3.7096, -4.5285, -5.1323, -2.3077,\n",
969
+ " -3.6645, -4.6847, -4.2481, -5.0417, -3.5121, -2.7808, -5.9767, -4.8864,\n",
970
+ " -5.6730, -4.6838, -3.8588, -5.2819, -3.9295, -2.7704, 0.4331, -4.5505,\n",
971
+ " -5.2648, -4.9248, -4.2074, -3.4895, -3.2717, -5.2713, -5.7536, -7.2749,\n",
972
+ " -4.8728, -5.2606, -4.5935, -4.7103, -5.4628, -5.4589, -5.3678, -3.5648,\n",
973
+ " -5.1455, -8.8455, -9.1583, -6.4358, -4.7737, -4.7821, -8.9264, -5.8790,\n",
974
+ " -4.7536, -5.4549, -5.3879, -6.1918, -4.1667, -7.1828, -7.3235, -5.4470,\n",
975
+ " -4.6688, -4.7201, -6.2949, -7.5401, -6.6242, -6.1022, -5.5325, -3.1546,\n",
976
+ " -9.4200, -5.2060, -5.3880, -6.8743, -3.3176, -7.2654, -7.4301, -3.0929,\n",
977
+ " -3.2351, -9.0408, -5.4315, -6.3230, -9.5853, -5.7075, -3.6443, -5.5524,\n",
978
+ " -6.0723, -6.0414, -7.3201, -3.9738, -5.5964, -4.0455, -5.2017, -5.8061,\n",
979
+ " -7.8401, -7.5268, -7.4576, -4.4483, -6.4790, -5.9085, -6.8822, -5.4498,\n",
980
+ " -6.7494, -6.1449, -5.9297, -6.4985, -5.0379, -4.9914, -5.5201, -7.9075,\n",
981
+ " -8.7653, -6.6116, -6.6643, -9.3863, -4.9038, -7.6509, -9.0117, -9.1193,\n",
982
+ " -5.3166, -5.4046, -8.3876, -4.9028, -3.5257, -8.9734, -6.1487, -8.1408,\n",
983
+ " -5.3014, -6.5494, -6.8383, -4.8011, -5.2831, -8.7708, -7.5039, -5.3957,\n",
984
+ " -7.3326, -3.6551, -4.9892, -5.9366, -5.2093, -5.2362, -5.0462, -6.5469,\n",
985
+ " -4.9182, -4.4108, -7.1632, -5.9481, -5.3291, -6.4517, -5.6950, -8.7276,\n",
986
+ " -5.7762, -8.9848, -7.3795, -5.4210, -5.6845, -2.9447, -3.6166, -3.6258,\n",
987
+ " -1.4417, -5.6568, -3.5869]], device='cuda:0',\n",
988
+ " grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)"
989
+ ]
990
+ },
991
+ "execution_count": 10,
992
+ "metadata": {},
993
+ "output_type": "execute_result"
994
+ }
995
+ ],
996
+ "source": [
997
+ "model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function('Maths is cool $ In our article we prove that maths is the coolest subject at school').items()})"
998
+ ]
999
+ },
1000
+ {
1001
+ "cell_type": "code",
1002
+ "execution_count": 19,
1003
+ "metadata": {},
1004
+ "outputs": [],
1005
+ "source": [
1006
+ "@torch.no_grad\n",
1007
+ "def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]:\n",
1008
+ " text = f'{title} $ {summary}'\n",
1009
+ " tags_logits = model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function(text).items()}).logits\n",
1010
+ " sigmoid = torch.nn.Sigmoid()\n",
1011
+ " tags_probs = sigmoid(tags_logits.squeeze().cpu()).numpy()\n",
1012
+ " tags_probs /= tags_probs.sum()\n",
1013
+ " category_probs_dict = {category: 0.0 for category in set(arxiv_topics_df['category'])}\n",
1014
+ " for index in range(len(index_to_tag)):\n",
1015
+ " category_probs_dict[tag_to_category[index_to_tag[index]]] += float(tags_probs[index])\n",
1016
+ " return category_probs_dict"
1017
+ ]
1018
+ },
1019
+ {
1020
+ "cell_type": "code",
1021
+ "execution_count": 16,
1022
+ "metadata": {},
1023
+ "outputs": [],
1024
+ "source": [
1025
+ "def get_most_probable_keys(probs_dict: Dict[str, float], target_probability: float, print_probabilities: bool) -> List[str]:\n",
1026
+ " current_p = 0\n",
1027
+ " probs_list = sorted([(value, key) for key, value in probs_dict.items()])[::-1]\n",
1028
+ " current_index = 0\n",
1029
+ " answer = []\n",
1030
+ " while current_p <= target_probability:\n",
1031
+ " current_p += probs_list[current_index][0]\n",
1032
+ " if not print_probabilities:\n",
1033
+ " answer.append(probs_list[current_index][1])\n",
1034
+ " else:\n",
1035
+ " answer.append(f'{probs_list[current_index][1]} ({probs_list[current_index][0]})')\n",
1036
+ " current_index += 1\n",
1037
+ " if current_index >= len(probs_list):\n",
1038
+ " break\n",
1039
+ " return answer"
1040
+ ]
1041
+ },
1042
+ {
1043
+ "cell_type": "markdown",
1044
+ "metadata": {},
1045
+ "source": [
1046
+ "Теперь нужно как-то сохранить модель, чтобы потом можно было её использовать в huggingface space"
1047
+ ]
1048
+ },
1049
+ {
1050
+ "cell_type": "code",
1051
+ "execution_count": 6,
1052
+ "metadata": {},
1053
+ "outputs": [
1054
+ {
1055
+ "name": "stderr",
1056
+ "output_type": "stream",
1057
+ "text": [
1058
+ "model.safetensors: 100%|██████████| 264M/264M [00:31<00:00, 8.47MB/s] \n"
1059
+ ]
1060
+ },
1061
+ {
1062
+ "data": {
1063
+ "text/plain": [
1064
+ "CommitInfo(commit_url='https://huggingface.co/bumchik2/train_distilbert-base-cased-tags-classification-simple/commit/98a87d7c96e0647dd557a9d47be03ddd30e0c964', commit_message='Upload DistilBertForSequenceClassification', commit_description='', oid='98a87d7c96e0647dd557a9d47be03ddd30e0c964', pr_url=None, repo_url=RepoUrl('https://huggingface.co/bumchik2/train_distilbert-base-cased-tags-classification-simple', endpoint='https://huggingface.co', repo_type='model', repo_id='bumchik2/train_distilbert-base-cased-tags-classification-simple'), pr_revision=None, pr_num=None)"
1065
+ ]
1066
+ },
1067
+ "execution_count": 6,
1068
+ "metadata": {},
1069
+ "output_type": "execute_result"
1070
+ }
1071
+ ],
1072
+ "source": [
1073
+ "model.push_to_hub(\"bumchik2/train_distilbert-base-cased-tags-classification-simple\")"
1074
+ ]
1075
+ },
1076
+ {
1077
+ "cell_type": "markdown",
1078
+ "metadata": {},
1079
+ "source": [
1080
+ "Теперь я смогу загружать свою модель оттуда"
1081
+ ]
1082
+ },
1083
+ {
1084
+ "cell_type": "code",
1085
+ "execution_count": 7,
1086
+ "metadata": {},
1087
+ "outputs": [],
1088
+ "source": [
1089
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
1090
+ " \"bumchik2/train_distilbert-base-cased-tags-classification-simple\", \n",
1091
+ " problem_type=\"multi_label_classification\", \n",
1092
+ " num_labels=len(tag_to_index),\n",
1093
+ " id2label=index_to_tag,\n",
1094
+ " label2id=tag_to_index\n",
1095
+ ").to(torch.device('cuda'))"
1096
+ ]
1097
+ },
1098
+ {
1099
+ "cell_type": "code",
1100
+ "execution_count": null,
1101
+ "metadata": {},
1102
+ "outputs": [],
1103
+ "source": []
1104
+ }
1105
+ ],
1106
+ "metadata": {
1107
+ "kernelspec": {
1108
+ "display_name": "Tricks",
1109
+ "language": "python",
1110
+ "name": "python3"
1111
+ },
1112
+ "language_info": {
1113
+ "codemirror_mode": {
1114
+ "name": "ipython",
1115
+ "version": 3
1116
+ },
1117
+ "file_extension": ".py",
1118
+ "mimetype": "text/x-python",
1119
+ "name": "python",
1120
+ "nbconvert_exporter": "python",
1121
+ "pygments_lexer": "ipython3",
1122
+ "version": "3.10.12"
1123
+ }
1124
+ },
1125
+ "nbformat": 4,
1126
+ "nbformat_minor": 2
1127
+ }
notebooks/distilroberta_base_main.ipynb ADDED
@@ -0,0 +1,1180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "metadata": {},
7
+ "outputs": [
8
+ {
9
+ "name": "stderr",
10
+ "output_type": "stream",
11
+ "text": [
12
+ "/home/jarakcyc/.virtualenvs/Tricks/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
13
+ " from .autonotebook import tqdm as notebook_tqdm\n"
14
+ ]
15
+ },
16
+ {
17
+ "name": "stdout",
18
+ "output_type": "stream",
19
+ "text": [
20
+ "4.48.3\n"
21
+ ]
22
+ },
23
+ {
24
+ "data": {
25
+ "text/plain": [
26
+ "device(type='cuda')"
27
+ ]
28
+ },
29
+ "execution_count": 1,
30
+ "metadata": {},
31
+ "output_type": "execute_result"
32
+ }
33
+ ],
34
+ "source": [
35
+ "from transformers import pipeline\n",
36
+ "import json\n",
37
+ "import pandas as pd\n",
38
+ "from sklearn.model_selection import train_test_split\n",
39
+ "from transformers import RobertaTokenizer\n",
40
+ "from tqdm import tqdm\n",
41
+ "import re\n",
42
+ "from datasets import Dataset\n",
43
+ "from transformers import AutoModelForSequenceClassification\n",
44
+ "import torch\n",
45
+ "import numpy as np\n",
46
+ "from typing import Dict\n",
47
+ "from transformers import AutoModel\n",
48
+ "from torch.nn import BCEWithLogitsLoss\n",
49
+ "from typing import List\n",
50
+ "from transformers import TrainingArguments, Trainer\n",
51
+ "from collections import defaultdict\n",
52
+ "\n",
53
+ "from transformers import __version__ as transformers_version\n",
54
+ "print(transformers_version)\n",
55
+ "\n",
56
+ "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
57
+ "DEVICE"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": 2,
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "USED_MODEL = \"distilroberta-base\""
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": 3,
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "def read_json(json_filename):\n",
76
+ " with open(json_filename, 'r') as f:\n",
77
+ " return json.loads(f.read())\n",
78
+ "\n",
79
+ "\n",
80
+ "def save_json(json_object, json_filename, indent=4):\n",
81
+ " with open(json_filename, 'w') as f:\n",
82
+ " json.dump(json_object, f, separators=(',', ':'), indent=indent)"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "markdown",
87
+ "metadata": {},
88
+ "source": [
89
+ "**Данные берем отсюда: https://www.kaggle.com/datasets/neelshah18/arxivdataset**"
90
+ ]
91
+ },
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": 4,
95
+ "metadata": {},
96
+ "outputs": [],
97
+ "source": [
98
+ "arxiv_data = read_json('arxivData.json')"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": 5,
104
+ "metadata": {},
105
+ "outputs": [
106
+ {
107
+ "data": {
108
+ "text/plain": [
109
+ "{'author': \"[{'name': 'Ahmed Osman'}, {'name': 'Wojciech Samek'}]\",\n",
110
+ " 'day': 1,\n",
111
+ " 'id': '1802.00209v1',\n",
112
+ " 'link': \"[{'rel': 'alternate', 'href': 'http://arxiv.org/abs/1802.00209v1', 'type': 'text/html'}, {'rel': 'related', 'href': 'http://arxiv.org/pdf/1802.00209v1', 'type': 'application/pdf', 'title': 'pdf'}]\",\n",
113
+ " 'month': 2,\n",
114
+ " 'summary': 'We propose an architecture for VQA which utilizes recurrent layers to\\ngenerate visual and textual attention. The memory characteristic of the\\nproposed recurrent attention units offers a rich joint embedding of visual and\\ntextual features and enables the model to reason relations between several\\nparts of the image and question. Our single model outperforms the first place\\nwinner on the VQA 1.0 dataset, performs within margin to the current\\nstate-of-the-art ensemble model. We also experiment with replacing attention\\nmechanisms in other state-of-the-art models with our implementation and show\\nincreased accuracy. In both cases, our recurrent attention mechanism improves\\nperformance in tasks requiring sequential or relational reasoning on the VQA\\ndataset.',\n",
115
+ " 'tag': \"[{'term': 'cs.AI', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'cs.CL', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'cs.CV', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'cs.NE', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}, {'term': 'stat.ML', 'scheme': 'http://arxiv.org/schemas/atom', 'label': None}]\",\n",
116
+ " 'title': 'Dual Recurrent Attention Units for Visual Question Answering',\n",
117
+ " 'year': 2018}"
118
+ ]
119
+ },
120
+ "execution_count": 5,
121
+ "metadata": {},
122
+ "output_type": "execute_result"
123
+ }
124
+ ],
125
+ "source": [
126
+ "arxiv_data[0]"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "markdown",
131
+ "metadata": {},
132
+ "source": [
133
+ "**Хотим по названию статьи + abstract выдавать наиболее вероятную тематику статьи, скажем, физика, биология или computer science** "
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "execution_count": 6,
139
+ "metadata": {},
140
+ "outputs": [
141
+ {
142
+ "name": "stdout",
143
+ "output_type": "stream",
144
+ "text": [
145
+ "155\n"
146
+ ]
147
+ },
148
+ {
149
+ "data": {
150
+ "text/html": [
151
+ "<div>\n",
152
+ "<style scoped>\n",
153
+ " .dataframe tbody tr th:only-of-type {\n",
154
+ " vertical-align: middle;\n",
155
+ " }\n",
156
+ "\n",
157
+ " .dataframe tbody tr th {\n",
158
+ " vertical-align: top;\n",
159
+ " }\n",
160
+ "\n",
161
+ " .dataframe thead th {\n",
162
+ " text-align: right;\n",
163
+ " }\n",
164
+ "</style>\n",
165
+ "<table border=\"1\" class=\"dataframe\">\n",
166
+ " <thead>\n",
167
+ " <tr style=\"text-align: right;\">\n",
168
+ " <th></th>\n",
169
+ " <th>tag</th>\n",
170
+ " <th>topic</th>\n",
171
+ " <th>category</th>\n",
172
+ " </tr>\n",
173
+ " </thead>\n",
174
+ " <tbody>\n",
175
+ " <tr>\n",
176
+ " <th>0</th>\n",
177
+ " <td>cs.AI</td>\n",
178
+ " <td>Artificial Intelligence</td>\n",
179
+ " <td>Computer Science</td>\n",
180
+ " </tr>\n",
181
+ " <tr>\n",
182
+ " <th>1</th>\n",
183
+ " <td>cs.AR</td>\n",
184
+ " <td>Hardware Architecture</td>\n",
185
+ " <td>Computer Science</td>\n",
186
+ " </tr>\n",
187
+ " <tr>\n",
188
+ " <th>2</th>\n",
189
+ " <td>cs.CC</td>\n",
190
+ " <td>Computational Complexity</td>\n",
191
+ " <td>Computer Science</td>\n",
192
+ " </tr>\n",
193
+ " <tr>\n",
194
+ " <th>3</th>\n",
195
+ " <td>cs.CE</td>\n",
196
+ " <td>Computational Engineering, Finance, and Science</td>\n",
197
+ " <td>Computer Science</td>\n",
198
+ " </tr>\n",
199
+ " <tr>\n",
200
+ " <th>4</th>\n",
201
+ " <td>cs.CG</td>\n",
202
+ " <td>Computational Geometry</td>\n",
203
+ " <td>Computer Science</td>\n",
204
+ " </tr>\n",
205
+ " </tbody>\n",
206
+ "</table>\n",
207
+ "</div>"
208
+ ],
209
+ "text/plain": [
210
+ " tag topic category\n",
211
+ "0 cs.AI Artificial Intelligence Computer Science\n",
212
+ "1 cs.AR Hardware Architecture Computer Science\n",
213
+ "2 cs.CC Computational Complexity Computer Science\n",
214
+ "3 cs.CE Computational Engineering, Finance, and Science Computer Science\n",
215
+ "4 cs.CG Computational Geometry Computer Science"
216
+ ]
217
+ },
218
+ "execution_count": 6,
219
+ "metadata": {},
220
+ "output_type": "execute_result"
221
+ }
222
+ ],
223
+ "source": [
224
+ "arxiv_topics_df = pd.read_csv('arxiv_topics.csv')\n",
225
+ "print(len(arxiv_topics_df))\n",
226
+ "arxiv_topics_df.head(5)"
227
+ ]
228
+ },
229
+ {
230
+ "cell_type": "code",
231
+ "execution_count": 7,
232
+ "metadata": {},
233
+ "outputs": [],
234
+ "source": [
235
+ "category_to_index = {}\n",
236
+ "tag_to_category = {}\n",
237
+ "current_index = 0\n",
238
+ "for i, row in arxiv_topics_df.iterrows():\n",
239
+ " category = row['category']\n",
240
+ " if category not in category_to_index:\n",
241
+ " category_to_index[category] = current_index\n",
242
+ " current_index += 1\n",
243
+ " tag_to_category[row['tag']] = row['category']\n",
244
+ "index_to_category = {value: key for key, value in category_to_index.items()}"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "markdown",
249
+ "metadata": {},
250
+ "source": [
251
+ "**Готовим данные к обучению**"
252
+ ]
253
+ },
254
+ {
255
+ "cell_type": "code",
256
+ "execution_count": 8,
257
+ "metadata": {},
258
+ "outputs": [
259
+ {
260
+ "name": "stderr",
261
+ "output_type": "stream",
262
+ "text": [
263
+ " 0%| | 0/41000 [00:00<?, ?it/s]"
264
+ ]
265
+ },
266
+ {
267
+ "name": "stderr",
268
+ "output_type": "stream",
269
+ "text": [
270
+ "100%|██████████| 41000/41000 [00:01<00:00, 40017.30it/s]"
271
+ ]
272
+ },
273
+ {
274
+ "name": "stdout",
275
+ "output_type": "stream",
276
+ "text": [
277
+ "Среднее число категорий в одной статье: 1.3301219512195122\n",
278
+ "Среднее число тегов в одной статье: 1.8489024390243902\n"
279
+ ]
280
+ },
281
+ {
282
+ "name": "stderr",
283
+ "output_type": "stream",
284
+ "text": [
285
+ "\n"
286
+ ]
287
+ }
288
+ ],
289
+ "source": [
290
+ "def is_valid_tag(tag: str) -> bool:\n",
291
+ " return tag in tag_to_category\n",
292
+ "\n",
293
+ "total_categories_count = 0\n",
294
+ "total_tags_count = 0\n",
295
+ "records = []\n",
296
+ "for arxiv_record in tqdm(arxiv_data):\n",
297
+ " record = {\n",
298
+ " 'title': arxiv_record['title'],\n",
299
+ " 'summary': arxiv_record['summary'],\n",
300
+ " 'title_and_summary': arxiv_record['title'] + ' $ ' + arxiv_record['summary'],\n",
301
+ " 'tags': [current_tag['term'] for current_tag in eval(arxiv_record['tag']) if is_valid_tag(current_tag['term'])]\n",
302
+ " }\n",
303
+ " categories = set(tag_to_category[tag] for tag in record['tags'])\n",
304
+ " total_categories_count += len(categories)\n",
305
+ " total_tags_count += len(record['tags'])\n",
306
+ " record['categories_indices'] = list(set([category_to_index[tag_to_category[tag]] for tag in record['tags']]))\n",
307
+ " assert len(record['tags']) > 0\n",
308
+ " records.append(record)\n",
309
+ "\n",
310
+ "print(f'Среднее число категорий в одной статье: {total_categories_count / len(arxiv_data)}')\n",
311
+ "print(f'Среднее число тегов в одной статье: {total_tags_count / len(arxiv_data)}')"
312
+ ]
313
+ },
314
+ {
315
+ "cell_type": "markdown",
316
+ "metadata": {},
317
+ "source": [
318
+ "Как видим, перед нами задача мультибинарной классификации.\n",
319
+ "\n",
320
+ "Тегов у одной статьи бывает много, это понятно, но и категорий тоже бывает много. То есть, условно статья может быть посвящена и физике и биологии одновременно.\n",
321
+ "\n",
322
+ "Попробуем обучить модель определять теги - так она потенциально может сохранить в себе больше информации, чем если ее обучить определять категории (которых гораздо меньше)."
323
+ ]
324
+ },
325
+ {
326
+ "cell_type": "markdown",
327
+ "metadata": {},
328
+ "source": [
329
+ "**Соединяем title и summary используя символ `$` - он редкий, при этом его знает токенайзер, поэтому не придется с ним дополнительно возиться**"
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "code",
334
+ "execution_count": 9,
335
+ "metadata": {},
336
+ "outputs": [
337
+ {
338
+ "name": "stdout",
339
+ "output_type": "stream",
340
+ "text": [
341
+ "41000\n"
342
+ ]
343
+ },
344
+ {
345
+ "data": {
346
+ "text/html": [
347
+ "<div>\n",
348
+ "<style scoped>\n",
349
+ " .dataframe tbody tr th:only-of-type {\n",
350
+ " vertical-align: middle;\n",
351
+ " }\n",
352
+ "\n",
353
+ " .dataframe tbody tr th {\n",
354
+ " vertical-align: top;\n",
355
+ " }\n",
356
+ "\n",
357
+ " .dataframe thead th {\n",
358
+ " text-align: right;\n",
359
+ " }\n",
360
+ "</style>\n",
361
+ "<table border=\"1\" class=\"dataframe\">\n",
362
+ " <thead>\n",
363
+ " <tr style=\"text-align: right;\">\n",
364
+ " <th></th>\n",
365
+ " <th>title</th>\n",
366
+ " <th>summary</th>\n",
367
+ " <th>title_and_summary</th>\n",
368
+ " <th>tags</th>\n",
369
+ " <th>categories_indices</th>\n",
370
+ " </tr>\n",
371
+ " </thead>\n",
372
+ " <tbody>\n",
373
+ " <tr>\n",
374
+ " <th>0</th>\n",
375
+ " <td>Dual Recurrent Attention Units for Visual Ques...</td>\n",
376
+ " <td>We propose an architecture for VQA which utili...</td>\n",
377
+ " <td>Dual Recurrent Attention Units for Visual Ques...</td>\n",
378
+ " <td>[cs.AI, cs.CL, cs.CV, cs.NE, stat.ML]</td>\n",
379
+ " <td>[0, 7]</td>\n",
380
+ " </tr>\n",
381
+ " <tr>\n",
382
+ " <th>1</th>\n",
383
+ " <td>Sequential Short-Text Classification with Recu...</td>\n",
384
+ " <td>Recent approaches based on artificial neural n...</td>\n",
385
+ " <td>Sequential Short-Text Classification with Recu...</td>\n",
386
+ " <td>[cs.CL, cs.AI, cs.LG, cs.NE, stat.ML]</td>\n",
387
+ " <td>[0, 7]</td>\n",
388
+ " </tr>\n",
389
+ " <tr>\n",
390
+ " <th>2</th>\n",
391
+ " <td>Multiresolution Recurrent Neural Networks: An ...</td>\n",
392
+ " <td>We introduce the multiresolution recurrent neu...</td>\n",
393
+ " <td>Multiresolution Recurrent Neural Networks: An ...</td>\n",
394
+ " <td>[cs.CL, cs.AI, cs.LG, cs.NE, stat.ML]</td>\n",
395
+ " <td>[0, 7]</td>\n",
396
+ " </tr>\n",
397
+ " <tr>\n",
398
+ " <th>3</th>\n",
399
+ " <td>Learning what to share between loosely related...</td>\n",
400
+ " <td>Multi-task learning is motivated by the observ...</td>\n",
401
+ " <td>Learning what to share between loosely related...</td>\n",
402
+ " <td>[stat.ML, cs.AI, cs.CL, cs.LG, cs.NE]</td>\n",
403
+ " <td>[0, 7]</td>\n",
404
+ " </tr>\n",
405
+ " <tr>\n",
406
+ " <th>4</th>\n",
407
+ " <td>A Deep Reinforcement Learning Chatbot</td>\n",
408
+ " <td>We present MILABOT: a deep reinforcement learn...</td>\n",
409
+ " <td>A Deep Reinforcement Learning Chatbot $ We pre...</td>\n",
410
+ " <td>[cs.CL, cs.AI, cs.LG, cs.NE, stat.ML]</td>\n",
411
+ " <td>[0, 7]</td>\n",
412
+ " </tr>\n",
413
+ " </tbody>\n",
414
+ "</table>\n",
415
+ "</div>"
416
+ ],
417
+ "text/plain": [
418
+ " title \\\n",
419
+ "0 Dual Recurrent Attention Units for Visual Ques... \n",
420
+ "1 Sequential Short-Text Classification with Recu... \n",
421
+ "2 Multiresolution Recurrent Neural Networks: An ... \n",
422
+ "3 Learning what to share between loosely related... \n",
423
+ "4 A Deep Reinforcement Learning Chatbot \n",
424
+ "\n",
425
+ " summary \\\n",
426
+ "0 We propose an architecture for VQA which utili... \n",
427
+ "1 Recent approaches based on artificial neural n... \n",
428
+ "2 We introduce the multiresolution recurrent neu... \n",
429
+ "3 Multi-task learning is motivated by the observ... \n",
430
+ "4 We present MILABOT: a deep reinforcement learn... \n",
431
+ "\n",
432
+ " title_and_summary \\\n",
433
+ "0 Dual Recurrent Attention Units for Visual Ques... \n",
434
+ "1 Sequential Short-Text Classification with Recu... \n",
435
+ "2 Multiresolution Recurrent Neural Networks: An ... \n",
436
+ "3 Learning what to share between loosely related... \n",
437
+ "4 A Deep Reinforcement Learning Chatbot $ We pre... \n",
438
+ "\n",
439
+ " tags categories_indices \n",
440
+ "0 [cs.AI, cs.CL, cs.CV, cs.NE, stat.ML] [0, 7] \n",
441
+ "1 [cs.CL, cs.AI, cs.LG, cs.NE, stat.ML] [0, 7] \n",
442
+ "2 [cs.CL, cs.AI, cs.LG, cs.NE, stat.ML] [0, 7] \n",
443
+ "3 [stat.ML, cs.AI, cs.CL, cs.LG, cs.NE] [0, 7] \n",
444
+ "4 [cs.CL, cs.AI, cs.LG, cs.NE, stat.ML] [0, 7] "
445
+ ]
446
+ },
447
+ "execution_count": 9,
448
+ "metadata": {},
449
+ "output_type": "execute_result"
450
+ }
451
+ ],
452
+ "source": [
453
+ "full_data_df = pd.DataFrame(records)\n",
454
+ "print(len(full_data_df))\n",
455
+ "full_data_df.head(5)"
456
+ ]
457
+ },
458
+ {
459
+ "cell_type": "markdown",
460
+ "metadata": {},
461
+ "source": [
462
+ "**Как видим, Computer science встречается очень часто. А, например, экономика - совсем редко. Значит при обучении экономике логично давать больше вес**"
463
+ ]
464
+ },
465
+ {
466
+ "cell_type": "code",
467
+ "execution_count": 10,
468
+ "metadata": {},
469
+ "outputs": [],
470
+ "source": [
471
+ "text_data = list(full_data_df['title_and_summary'])\n",
472
+ "categories_indices = list(full_data_df['categories_indices'])"
473
+ ]
474
+ },
475
+ {
476
+ "cell_type": "code",
477
+ "execution_count": 11,
478
+ "metadata": {},
479
+ "outputs": [
480
+ {
481
+ "name": "stdout",
482
+ "output_type": "stream",
483
+ "text": [
484
+ "28700 8200 4100\n"
485
+ ]
486
+ }
487
+ ],
488
+ "source": [
489
+ "X_train_val, X_test, y_train_val, y_test = train_test_split(text_data, categories_indices, test_size=0.1, random_state=42)\n",
490
+ "X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=2/9, random_state=42)\n",
491
+ "print(len(X_train), len(X_val), len(X_test))\n",
492
+ "# Train is 70%, val is 20%, test is 10%"
493
+ ]
494
+ },
495
+ {
496
+ "cell_type": "markdown",
497
+ "metadata": {},
498
+ "source": [
499
+ "Посмотрим на распределение категорий в тренировочной выборке"
500
+ ]
501
+ },
502
+ {
503
+ "cell_type": "code",
504
+ "execution_count": 12,
505
+ "metadata": {},
506
+ "outputs": [
507
+ {
508
+ "name": "stdout",
509
+ "output_type": "stream",
510
+ "text": [
511
+ "{0: 27475, 3: 1591, 7: 7417, 5: 623, 2: 152, 4: 840, 6: 43, 1: 9}\n"
512
+ ]
513
+ }
514
+ ],
515
+ "source": [
516
+ "category_to_count = defaultdict(int)\n",
517
+ "for row in y_train:\n",
518
+ " for category in row:\n",
519
+ " category_to_count[category] += 1\n",
520
+ "category_to_count = dict(category_to_count)\n",
521
+ "print(category_to_count)"
522
+ ]
523
+ },
524
+ {
525
+ "cell_type": "code",
526
+ "execution_count": 8,
527
+ "metadata": {},
528
+ "outputs": [],
529
+ "source": [
530
+ "tokenizer = RobertaTokenizer.from_pretrained(USED_MODEL)\n",
531
+ "def tokenize_function(text):\n",
532
+ " return tokenizer(text, padding=\"max_length\", truncation=True)"
533
+ ]
534
+ },
535
+ {
536
+ "cell_type": "code",
537
+ "execution_count": 17,
538
+ "metadata": {},
539
+ "outputs": [
540
+ {
541
+ "name": "stdout",
542
+ "output_type": "stream",
543
+ "text": [
544
+ "<class 'transformers.tokenization_utils_base.BatchEncoding'>\n",
545
+ "['_MutableMapping__marker', '__abstractmethods__', '__class__', '__class_getitem__', '__contains__', '__copy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getitem__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__ior__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__ne__', '__new__', '__or__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__ror__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', '__weakref__', '_abc_impl', '_encodings', '_n_sequences', 'char_to_token', 'char_to_word', 'clear', 'convert_to_tensors', 'copy', 'data', 'encodings', 'fromkeys', 'get', 'is_fast', 'items', 'keys', 'n_sequences', 'pop', 'popitem', 'sequence_ids', 'setdefault', 'to', 'token_to_chars', 'token_to_sequence', 'token_to_word', 'tokens', 'update', 'values', 'word_ids', 'word_to_chars', 'word_to_tokens', 'words']\n",
546
+ "2\n"
547
+ ]
548
+ }
549
+ ],
550
+ "source": [
551
+ "train_encodings = tokenize_function(X_train)\n",
552
+ "val_encodings = tokenize_function(X_val)\n",
553
+ "test_encodings = tokenize_function(X_test)\n",
554
+ "\n",
555
+ "print(type(train_encodings))\n",
556
+ "print(dir(train_encodings))\n",
557
+ "print(len(train_encodings))"
558
+ ]
559
+ },
560
+ {
561
+ "cell_type": "code",
562
+ "execution_count": 18,
563
+ "metadata": {},
564
+ "outputs": [
565
+ {
566
+ "name": "stderr",
567
+ "output_type": "stream",
568
+ "text": [
569
+ "100%|██████████| 28700/28700 [00:00<00:00, 562525.53it/s]\n",
570
+ "100%|██████████| 8200/8200 [00:00<00:00, 548336.22it/s]\n",
571
+ "100%|██████████| 4100/4100 [00:00<00:00, 538421.57it/s]\n"
572
+ ]
573
+ }
574
+ ],
575
+ "source": [
576
+ "def get_labels(y: List[List[int]]):\n",
577
+ " labels = np.zeros((len(y), len(category_to_index)))\n",
578
+ " for i in tqdm(range(len(y))):\n",
579
+ " labels[i, y[i]] = 1\n",
580
+ " return labels.tolist()\n",
581
+ "\n",
582
+ "labels_train = get_labels(y_train)\n",
583
+ "labels_val = get_labels(y_val)\n",
584
+ "labels_test = get_labels(y_test)"
585
+ ]
586
+ },
587
+ {
588
+ "cell_type": "code",
589
+ "execution_count": 19,
590
+ "metadata": {},
591
+ "outputs": [],
592
+ "source": [
593
+ "train_encodings['labels'] = labels_train\n",
594
+ "val_encodings['labels'] = labels_val\n",
595
+ "test_encodings['labels'] = labels_test"
596
+ ]
597
+ },
598
+ {
599
+ "cell_type": "markdown",
600
+ "metadata": {},
601
+ "source": [
602
+ "**Я использовал пример отсюда чтобы понимать, какой нужен формат данных https://github.com/NielsRogge/Transformers-Tutorials/blob/master/BERT/Fine_tuning_BERT_(and_friends)_for_multi_label_text_classification.ipynb**"
603
+ ]
604
+ },
605
+ {
606
+ "cell_type": "code",
607
+ "execution_count": 20,
608
+ "metadata": {},
609
+ "outputs": [],
610
+ "source": [
611
+ "train_dataset = Dataset.from_dict(train_encodings)\n",
612
+ "val_dataset = Dataset.from_dict(val_encodings)\n",
613
+ "test_dataset = Dataset.from_dict(test_encodings)\n",
614
+ "\n",
615
+ "train_dataset.set_format(\"torch\")\n",
616
+ "val_dataset.set_format(\"torch\")\n",
617
+ "test_dataset.set_format(\"torch\")"
618
+ ]
619
+ },
620
+ {
621
+ "cell_type": "code",
622
+ "execution_count": 21,
623
+ "metadata": {},
624
+ "outputs": [
625
+ {
626
+ "name": "stderr",
627
+ "output_type": "stream",
628
+ "text": [
629
+ "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']\n",
630
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
631
+ ]
632
+ }
633
+ ],
634
+ "source": [
635
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
636
+ " USED_MODEL, \n",
637
+ " problem_type=\"multi_label_classification\", \n",
638
+ " num_labels=len(category_to_index),\n",
639
+ " id2label=index_to_category,\n",
640
+ " label2id=category_to_index\n",
641
+ ")"
642
+ ]
643
+ },
644
+ {
645
+ "cell_type": "code",
646
+ "execution_count": 22,
647
+ "metadata": {},
648
+ "outputs": [],
649
+ "source": [
650
+ "batch_size = 8\n",
651
+ "metric_name = \"f1\""
652
+ ]
653
+ },
654
+ {
655
+ "cell_type": "code",
656
+ "execution_count": 23,
657
+ "metadata": {},
658
+ "outputs": [
659
+ {
660
+ "name": "stderr",
661
+ "output_type": "stream",
662
+ "text": [
663
+ "/home/jarakcyc/.virtualenvs/Tricks/lib/python3.10/site-packages/transformers/training_args.py:1611: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
664
+ " warnings.warn(\n"
665
+ ]
666
+ }
667
+ ],
668
+ "source": [
669
+ "args = TrainingArguments(\n",
670
+ " output_dir=f'train-{USED_MODEL}',\n",
671
+ " evaluation_strategy=\"epoch\",\n",
672
+ " save_strategy=\"epoch\",\n",
673
+ " learning_rate=2e-5,\n",
674
+ " per_device_train_batch_size=batch_size,\n",
675
+ " per_device_eval_batch_size=batch_size,\n",
676
+ " num_train_epochs=5,\n",
677
+ " weight_decay=0.01,\n",
678
+ " load_best_model_at_end=True,\n",
679
+ " metric_for_best_model=metric_name,\n",
680
+ " push_to_hub=False\n",
681
+ ")"
682
+ ]
683
+ },
684
+ {
685
+ "cell_type": "code",
686
+ "execution_count": 24,
687
+ "metadata": {},
688
+ "outputs": [],
689
+ "source": [
690
+ "from sklearn.metrics import f1_score, roc_auc_score, accuracy_score\n",
691
+ "from transformers import EvalPrediction\n",
692
+ "import torch\n",
693
+ " \n",
694
+ "# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/\n",
695
+ "def multi_label_metrics(predictions, labels, threshold=0.5):\n",
696
+ " sigmoid = torch.nn.Sigmoid()\n",
697
+ " probs = sigmoid(torch.Tensor(predictions))\n",
698
+ " y_pred = np.zeros(probs.shape)\n",
699
+ " y_pred[np.where(probs >= threshold)] = 1\n",
700
+ " y_true = labels\n",
701
+ " f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')\n",
702
+ " roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')\n",
703
+ " accuracy = accuracy_score(y_true, y_pred)\n",
704
+ " metrics = {'f1': f1_micro_average,\n",
705
+ " 'roc_auc': roc_auc,\n",
706
+ " 'accuracy': accuracy}\n",
707
+ " return metrics\n",
708
+ "\n",
709
+ "def compute_metrics(p: EvalPrediction):\n",
710
+ " preds = p.predictions[0] if isinstance(p.predictions, \n",
711
+ " tuple) else p.predictions\n",
712
+ " result = multi_label_metrics(\n",
713
+ " predictions=preds, \n",
714
+ " labels=p.label_ids)\n",
715
+ " return result"
716
+ ]
717
+ },
718
+ {
719
+ "cell_type": "code",
720
+ "execution_count": 25,
721
+ "metadata": {},
722
+ "outputs": [
723
+ {
724
+ "data": {
725
+ "text/plain": [
726
+ "tensor([1.3057e-01, 3.9861e+02, 2.3602e+01, 2.2549e+00, 4.2708e+00, 5.7584e+00,\n",
727
+ " 8.3430e+01, 4.8369e-01], device='cuda:0')"
728
+ ]
729
+ },
730
+ "execution_count": 25,
731
+ "metadata": {},
732
+ "output_type": "execute_result"
733
+ }
734
+ ],
735
+ "source": [
736
+ "pos_weight=torch.tensor([\n",
737
+ " len(y_train) / category_to_count[i] / len(category_to_count) for i in range(len(category_to_count))\n",
738
+ "]).to(DEVICE)\n",
739
+ "compute_loss_func_ = BCEWithLogitsLoss(pos_weight=pos_weight)\n",
740
+ "\n",
741
+ "# Example of custom trainer is taken from https://medium.com/deeplearningmadeeasy/how-to-use-a-custom-loss-with-hugging-face-fc9a1f91b39b\n",
742
+ "class CustomTrainer(Trainer):\n",
743
+ " def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):\n",
744
+ " labels = inputs.pop(\"labels\")\n",
745
+ " outputs = model(**inputs)\n",
746
+ " logits = outputs.logits\n",
747
+ " loss = compute_loss_func_(logits, labels)\n",
748
+ " return (loss, outputs) if return_outputs else loss\n",
749
+ "\n",
750
+ "pos_weight"
751
+ ]
752
+ },
753
+ {
754
+ "cell_type": "code",
755
+ "execution_count": 26,
756
+ "metadata": {},
757
+ "outputs": [
758
+ {
759
+ "name": "stderr",
760
+ "output_type": "stream",
761
+ "text": [
762
+ "/tmp/ipykernel_758780/1711637572.py:1: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `CustomTrainer.__init__`. Use `processing_class` instead.\n",
763
+ " trainer = CustomTrainer(\n"
764
+ ]
765
+ }
766
+ ],
767
+ "source": [
768
+ "trainer = CustomTrainer(\n",
769
+ " model,\n",
770
+ " args,\n",
771
+ " train_dataset=train_dataset,\n",
772
+ " eval_dataset=val_dataset,\n",
773
+ " tokenizer=tokenizer,\n",
774
+ " compute_metrics=compute_metrics\n",
775
+ ")"
776
+ ]
777
+ },
778
+ {
779
+ "cell_type": "code",
780
+ "execution_count": 27,
781
+ "metadata": {},
782
+ "outputs": [
783
+ {
784
+ "data": {
785
+ "text/html": [
786
+ "\n",
787
+ " <div>\n",
788
+ " \n",
789
+ " <progress value='17940' max='17940' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
790
+ " [17940/17940 31:14, Epoch 5/5]\n",
791
+ " </div>\n",
792
+ " <table border=\"1\" class=\"dataframe\">\n",
793
+ " <thead>\n",
794
+ " <tr style=\"text-align: left;\">\n",
795
+ " <th>Epoch</th>\n",
796
+ " <th>Training Loss</th>\n",
797
+ " <th>Validation Loss</th>\n",
798
+ " <th>F1</th>\n",
799
+ " <th>Roc Auc</th>\n",
800
+ " <th>Accuracy</th>\n",
801
+ " </tr>\n",
802
+ " </thead>\n",
803
+ " <tbody>\n",
804
+ " <tr>\n",
805
+ " <td>1</td>\n",
806
+ " <td>0.488900</td>\n",
807
+ " <td>0.342928</td>\n",
808
+ " <td>0.828636</td>\n",
809
+ " <td>0.869794</td>\n",
810
+ " <td>0.684146</td>\n",
811
+ " </tr>\n",
812
+ " <tr>\n",
813
+ " <td>2</td>\n",
814
+ " <td>0.238000</td>\n",
815
+ " <td>0.350722</td>\n",
816
+ " <td>0.835063</td>\n",
817
+ " <td>0.873791</td>\n",
818
+ " <td>0.688537</td>\n",
819
+ " </tr>\n",
820
+ " <tr>\n",
821
+ " <td>3</td>\n",
822
+ " <td>0.300600</td>\n",
823
+ " <td>0.338232</td>\n",
824
+ " <td>0.835975</td>\n",
825
+ " <td>0.882930</td>\n",
826
+ " <td>0.684268</td>\n",
827
+ " </tr>\n",
828
+ " <tr>\n",
829
+ " <td>4</td>\n",
830
+ " <td>0.138400</td>\n",
831
+ " <td>0.370107</td>\n",
832
+ " <td>0.850701</td>\n",
833
+ " <td>0.891974</td>\n",
834
+ " <td>0.706829</td>\n",
835
+ " </tr>\n",
836
+ " <tr>\n",
837
+ " <td>5</td>\n",
838
+ " <td>0.285600</td>\n",
839
+ " <td>0.383904</td>\n",
840
+ " <td>0.851726</td>\n",
841
+ " <td>0.895529</td>\n",
842
+ " <td>0.703659</td>\n",
843
+ " </tr>\n",
844
+ " </tbody>\n",
845
+ "</table><p>"
846
+ ],
847
+ "text/plain": [
848
+ "<IPython.core.display.HTML object>"
849
+ ]
850
+ },
851
+ "metadata": {},
852
+ "output_type": "display_data"
853
+ },
854
+ {
855
+ "data": {
856
+ "text/plain": [
857
+ "TrainOutput(global_step=17940, training_loss=0.2603886620256813, metrics={'train_runtime': 1875.3299, 'train_samples_per_second': 76.52, 'train_steps_per_second': 9.566, 'total_flos': 1.9011105705984e+16, 'train_loss': 0.2603886620256813, 'epoch': 5.0})"
858
+ ]
859
+ },
860
+ "execution_count": 27,
861
+ "metadata": {},
862
+ "output_type": "execute_result"
863
+ }
864
+ ],
865
+ "source": [
866
+ "trainer.train()"
867
+ ]
868
+ },
869
+ {
870
+ "cell_type": "code",
871
+ "execution_count": 28,
872
+ "metadata": {},
873
+ "outputs": [
874
+ {
875
+ "data": {
876
+ "text/html": [
877
+ "\n",
878
+ " <div>\n",
879
+ " \n",
880
+ " <progress value='1' max='1025' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
881
+ " [ 1/1025 : < :]\n",
882
+ " </div>\n",
883
+ " "
884
+ ],
885
+ "text/plain": [
886
+ "<IPython.core.display.HTML object>"
887
+ ]
888
+ },
889
+ "metadata": {},
890
+ "output_type": "display_data"
891
+ },
892
+ {
893
+ "data": {
894
+ "text/plain": [
895
+ "{'eval_loss': 0.3839036524295807,\n",
896
+ " 'eval_f1': 0.8517256276100418,\n",
897
+ " 'eval_roc_auc': 0.8955286914409087,\n",
898
+ " 'eval_accuracy': 0.7036585365853658,\n",
899
+ " 'eval_runtime': 29.6699,\n",
900
+ " 'eval_samples_per_second': 276.374,\n",
901
+ " 'eval_steps_per_second': 34.547,\n",
902
+ " 'epoch': 5.0}"
903
+ ]
904
+ },
905
+ "execution_count": 28,
906
+ "metadata": {},
907
+ "output_type": "execute_result"
908
+ }
909
+ ],
910
+ "source": [
911
+ "trainer.evaluate(eval_dataset=val_dataset)"
912
+ ]
913
+ },
914
+ {
915
+ "cell_type": "code",
916
+ "execution_count": 29,
917
+ "metadata": {},
918
+ "outputs": [
919
+ {
920
+ "data": {
921
+ "text/plain": [
922
+ "{'eval_loss': 0.4102073907852173,\n",
923
+ " 'eval_f1': 0.8496072917676719,\n",
924
+ " 'eval_roc_auc': 0.8946315908026856,\n",
925
+ " 'eval_accuracy': 0.7048780487804878,\n",
926
+ " 'eval_runtime': 15.9341,\n",
927
+ " 'eval_samples_per_second': 257.309,\n",
928
+ " 'eval_steps_per_second': 32.195,\n",
929
+ " 'epoch': 5.0}"
930
+ ]
931
+ },
932
+ "execution_count": 29,
933
+ "metadata": {},
934
+ "output_type": "execute_result"
935
+ }
936
+ ],
937
+ "source": [
938
+ "trainer.evaluate(eval_dataset=test_dataset)"
939
+ ]
940
+ },
941
+ {
942
+ "cell_type": "markdown",
943
+ "metadata": {},
944
+ "source": [
945
+ "Исходная задача у нас звучала как \"хотим увидеть топ-95%* тематик, отсортированных по убыванию вероятности\", где под тематиками имелись ввиду категории (физика, биология и так далее)\n",
946
+ "\n",
947
+ "Будем делать следующее:\n",
948
+ "- наша модель выдает логиты категорий\n",
949
+ "- посчитаем с их помощью вероятность категорий, считая их сумму равной 1 (хотя на самом деле категорий может быть несколько)\n",
950
+ "- выведем требуемые топ-95% тематик"
951
+ ]
952
+ },
953
+ {
954
+ "cell_type": "code",
955
+ "execution_count": 9,
956
+ "metadata": {},
957
+ "outputs": [],
958
+ "source": [
959
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
960
+ " f\"train-{USED_MODEL}/checkpoint-10764\", \n",
961
+ " problem_type=\"multi_label_classification\", \n",
962
+ " num_labels=len(category_to_index),\n",
963
+ " id2label=index_to_category,\n",
964
+ " label2id=category_to_index\n",
965
+ ").to(DEVICE)"
966
+ ]
967
+ },
968
+ {
969
+ "cell_type": "code",
970
+ "execution_count": 10,
971
+ "metadata": {},
972
+ "outputs": [
973
+ {
974
+ "data": {
975
+ "text/plain": [
976
+ "SequenceClassifierOutput(loss=None, logits=tensor([[ 4.5600, -8.8512, -6.1677, -3.4470, -4.3587, -4.2807, -7.7941, -5.2795]],\n",
977
+ " device='cuda:0', grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)"
978
+ ]
979
+ },
980
+ "execution_count": 10,
981
+ "metadata": {},
982
+ "output_type": "execute_result"
983
+ }
984
+ ],
985
+ "source": [
986
+ "model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function('Maths is cool $ In our article we prove that maths is the coolest subject at school').items()})"
987
+ ]
988
+ },
989
+ {
990
+ "cell_type": "code",
991
+ "execution_count": 11,
992
+ "metadata": {},
993
+ "outputs": [],
994
+ "source": [
995
+ "@torch.no_grad\n",
996
+ "def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]:\n",
997
+ " text = f'{title} $ {summary}'\n",
998
+ " category_logits = model(**{key: torch.tensor(value).to(model.device).unsqueeze(0) for key, value in tokenize_function(text).items()}).logits\n",
999
+ " sigmoid = torch.nn.Sigmoid()\n",
1000
+ " category_probs = sigmoid(category_logits.squeeze().cpu()).numpy()\n",
1001
+ " category_probs /= category_probs.sum()\n",
1002
+ " category_probs_dict = {category: 0.0 for category in set(arxiv_topics_df['category'])}\n",
1003
+ " for index in range(len(index_to_category)):\n",
1004
+ " category_probs_dict[index_to_category[index]] += float(category_probs[index])\n",
1005
+ " return category_probs_dict"
1006
+ ]
1007
+ },
1008
+ {
1009
+ "cell_type": "code",
1010
+ "execution_count": 12,
1011
+ "metadata": {},
1012
+ "outputs": [],
1013
+ "source": [
1014
+ "def get_most_probable_keys(probs_dict: Dict[str, float], target_probability: float, print_probabilities: bool) -> List[str]:\n",
1015
+ " current_p = 0\n",
1016
+ " probs_list = sorted([(value, key) for key, value in probs_dict.items()])[::-1]\n",
1017
+ " current_index = 0\n",
1018
+ " answer = []\n",
1019
+ " while current_p <= target_probability:\n",
1020
+ " current_p += probs_list[current_index][0]\n",
1021
+ " if not print_probabilities:\n",
1022
+ " answer.append(probs_list[current_index][1])\n",
1023
+ " else:\n",
1024
+ " answer.append(f'{probs_list[current_index][1]} ({probs_list[current_index][0]})')\n",
1025
+ " current_index += 1\n",
1026
+ " if current_index >= len(probs_list):\n",
1027
+ " break\n",
1028
+ " return answer"
1029
+ ]
1030
+ },
1031
+ {
1032
+ "cell_type": "markdown",
1033
+ "metadata": {},
1034
+ "source": [
1035
+ "Сохраняем модель, чтобы потом можно было её использовать в huggingface space"
1036
+ ]
1037
+ },
1038
+ {
1039
+ "cell_type": "code",
1040
+ "execution_count": 13,
1041
+ "metadata": {},
1042
+ "outputs": [
1043
+ {
1044
+ "name": "stderr",
1045
+ "output_type": "stream",
1046
+ "text": [
1047
+ "model.safetensors: 100%|██████████| 329M/329M [00:32<00:00, 10.1MB/s] \n"
1048
+ ]
1049
+ },
1050
+ {
1051
+ "data": {
1052
+ "text/plain": [
1053
+ "CommitInfo(commit_url='https://huggingface.co/bumchik2/train-distilroberta-base-tags-classification/commit/4494249c82b4ad67f59f5a3b8ae3b49e51eb9425', commit_message='Upload RobertaForSequenceClassification', commit_description='', oid='4494249c82b4ad67f59f5a3b8ae3b49e51eb9425', pr_url=None, repo_url=RepoUrl('https://huggingface.co/bumchik2/train-distilroberta-base-tags-classification', endpoint='https://huggingface.co', repo_type='model', repo_id='bumchik2/train-distilroberta-base-tags-classification'), pr_revision=None, pr_num=None)"
1054
+ ]
1055
+ },
1056
+ "execution_count": 13,
1057
+ "metadata": {},
1058
+ "output_type": "execute_result"
1059
+ }
1060
+ ],
1061
+ "source": [
1062
+ "model.push_to_hub(f\"bumchik2/train-{USED_MODEL}-tags-classification\")"
1063
+ ]
1064
+ },
1065
+ {
1066
+ "cell_type": "markdown",
1067
+ "metadata": {},
1068
+ "source": [
1069
+ "Теперь я могу загружать свою модель оттуда"
1070
+ ]
1071
+ },
1072
+ {
1073
+ "cell_type": "code",
1074
+ "execution_count": 14,
1075
+ "metadata": {},
1076
+ "outputs": [],
1077
+ "source": [
1078
+ "model = AutoModelForSequenceClassification.from_pretrained(\n",
1079
+ " f\"bumchik2/train-{USED_MODEL}-tags-classification\", \n",
1080
+ " problem_type=\"multi_label_classification\", \n",
1081
+ " num_labels=len(category_to_index),\n",
1082
+ " id2label=index_to_category,\n",
1083
+ " label2id=category_to_index\n",
1084
+ ").to(DEVICE)"
1085
+ ]
1086
+ },
1087
+ {
1088
+ "cell_type": "code",
1089
+ "execution_count": 15,
1090
+ "metadata": {},
1091
+ "outputs": [
1092
+ {
1093
+ "data": {
1094
+ "text/plain": [
1095
+ "['Quantitative Biology (0.42838579416275024)',\n",
1096
+ " 'Statistics (0.3568098247051239)',\n",
1097
+ " 'Computer Science (0.09878081828355789)',\n",
1098
+ " 'Physics (0.07676041126251221)']"
1099
+ ]
1100
+ },
1101
+ "execution_count": 15,
1102
+ "metadata": {},
1103
+ "output_type": "execute_result"
1104
+ }
1105
+ ],
1106
+ "source": [
1107
+ "# правильный ответ Quantitative Biology\n",
1108
+ "get_most_probable_keys(\n",
1109
+ " probs_dict=get_category_probs_dict(\n",
1110
+ " model=model,\n",
1111
+ " title='Simulating cell populations with explicit cell cycle length -- implications to cell cycle dependent tumour therapy',\n",
1112
+ " summary='In this study, we present a stochastic simulation model designed to explicitly incorporate cell cycle length, overcoming limitations associated with classical compartmental models. Our approach employs a delay mechanism to represent the cell cycle, allowing the use of arbitrary distributions for cell cycle lengths. We demonstrate the feasibility of our model by fitting it to experimental data from melanoma cell lines previously studied by Vittadello et al. Notably, our model successfully replicates experimentally observed synchronization phenomena that multi-stage models could not adequately explain. By using a gamma distribution to model cell cycle lengths, we achieved excellent agreement between our simulations and empirical data, while significantly reducing computational complexity and parameter estimation challenges inherent in multi-stage approaches. Our results highlight the importance of explicitly incorporating cell cycle lengths in modeling cell populations, with potential implications for optimizing cell cycle-dependent tumor therapies.'\n",
1113
+ " ),\n",
1114
+ " target_probability=0.95,\n",
1115
+ " print_probabilities=True\n",
1116
+ ")"
1117
+ ]
1118
+ },
1119
+ {
1120
+ "cell_type": "code",
1121
+ "execution_count": 16,
1122
+ "metadata": {},
1123
+ "outputs": [
1124
+ {
1125
+ "data": {
1126
+ "text/plain": [
1127
+ "['Physics (0.4553513824939728)',\n",
1128
+ " 'Computer Science (0.43614745140075684)',\n",
1129
+ " 'Electrical Engineering and Systems Science (0.04562709107995033)',\n",
1130
+ " 'Statistics (0.02984526939690113)']"
1131
+ ]
1132
+ },
1133
+ "execution_count": 16,
1134
+ "metadata": {},
1135
+ "output_type": "execute_result"
1136
+ }
1137
+ ],
1138
+ "source": [
1139
+ "# правильный ответ Physics\n",
1140
+ "get_most_probable_keys(\n",
1141
+ " probs_dict=get_category_probs_dict(\n",
1142
+ " model=model,\n",
1143
+ " title='Performance Improvement of LTS Undulators for Synchrotron Light Sources',\n",
1144
+ " summary='The joint expertise of ANL and FNAL has led to the production of Nb3Sn undulator magnets in operation in the ANL Advanced Photon Source (APS). These magnets showed performance reproducibility close to the short sample limit, and a design field increase of 20% at 820A. However, the long training did not allow obtaining the expected 50% increase of the on-axis magnetic field with respect to the ~1 T produced at 450 A current in the ANL NbTi undulator. To address this, 10-pole long undulator prototypes were fabricated, and CTD-101K was replaced as impregnation material with TELENE, an organic olefin-based thermosetting dicyclopentadiene resin produced by RIMTEC Corporation, Japan. Training and magnet retraining after a thermal cycle were nearly eliminated, with only a couple of quenches needed before reaching short sample limit at over 1,100 A. TELENE will enable operation of Nb3Sn undulators much closer to their short sample limit, expanding the energy range and brightness intensity of light sources. TELENE is Co-60 gamma radiation resistant up to 7-8 MGy, and therefore already applicable to impregnate planar, helical and universal devices operating in lower radiation environments than high energy colliders.'\n",
1145
+ " ),\n",
1146
+ " target_probability=0.95,\n",
1147
+ " print_probabilities=True\n",
1148
+ ")"
1149
+ ]
1150
+ },
1151
+ {
1152
+ "cell_type": "code",
1153
+ "execution_count": null,
1154
+ "metadata": {},
1155
+ "outputs": [],
1156
+ "source": []
1157
+ }
1158
+ ],
1159
+ "metadata": {
1160
+ "kernelspec": {
1161
+ "display_name": "Tricks",
1162
+ "language": "python",
1163
+ "name": "python3"
1164
+ },
1165
+ "language_info": {
1166
+ "codemirror_mode": {
1167
+ "name": "ipython",
1168
+ "version": 3
1169
+ },
1170
+ "file_extension": ".py",
1171
+ "mimetype": "text/x-python",
1172
+ "name": "python",
1173
+ "nbconvert_exporter": "python",
1174
+ "pygments_lexer": "ipython3",
1175
+ "version": "3.10.12"
1176
+ }
1177
+ },
1178
+ "nbformat": 4,
1179
+ "nbformat_minor": 2
1180
+ }