{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Классификация тематик статей arXiv по названию и аннотации\n",
    "\n",
    "Дообучаем `albert-base-v2` на многометочную (multi-label) классификацию категорий arXiv (Computer Science, Physics, Mathematics, ...) по тексту «название + abstract» и выдаём минимальный набор самых вероятных категорий, суммарная вероятность которых не меньше 95%.\n",
    "\n",
    "Проверено с `transformers==4.48.3`; обученная модель опубликована как `bumchik2/train-albert-base-v2-tags-classification`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ast\n",
    "import json\n",
    "import re\n",
    "from collections import defaultdict\n",
    "from typing import Dict, List\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from datasets import Dataset\n",
    "from sklearn.metrics import accuracy_score, f1_score, roc_auc_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "from torch.nn import BCEWithLogitsLoss\n",
    "from tqdm import tqdm\n",
    "from transformers import (\n",
    "    AlbertTokenizer,\n",
    "    AutoModel,\n",
    "    AutoModelForSequenceClassification,\n",
    "    EvalPrediction,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    "    pipeline,\n",
    "    set_seed,\n",
    ")\n",
    "from transformers import __version__ as transformers_version\n",
    "\n",
    "print(transformers_version)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "USED_MODEL = \"albert-base-v2\"\n",
    "SEED = 42  # data splits, classifier-head initialisation and Trainer\n",
    "\n",
    "DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "DEVICE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_json(json_filename):\n",
    "    \"\"\"Load and return the JSON document stored in `json_filename`.\"\"\"\n",
    "    with open(json_filename, 'r', encoding='utf-8') as f:\n",
    "        return json.load(f)\n",
    "\n",
    "\n",
    "def save_json(json_object, json_filename, indent=4):\n",
    "    \"\"\"Write `json_object` to `json_filename` as indented JSON.\"\"\"\n",
    "    with open(json_filename, 'w', encoding='utf-8') as f:\n",
    "        json.dump(json_object, f, separators=(',', ':'), indent=indent)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Данные берём отсюда: https://www.kaggle.com/datasets/neelshah18/arxivdataset**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "arxiv_data = read_json('arxivData.json')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "arxiv_data[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Хотим по названию статьи + abstract выдавать наиболее вероятную тематику статьи, скажем, физика, биология или computer science.**\n",
    "\n",
    "Файл `arxiv_topics.csv` сопоставляет каждому тегу arXiv (`tag`) его тему (`topic`) и категорию (`category`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "arxiv_topics_df = pd.read_csv('arxiv_topics.csv')\n",
    "print(len(arxiv_topics_df))\n",
    "arxiv_topics_df.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Category indices are assigned in order of first appearance in arxiv_topics.csv.\n",
    "category_to_index = {}\n",
    "tag_to_category = {}\n",
    "for tag, category in zip(arxiv_topics_df['tag'], arxiv_topics_df['category']):\n",
    "    if category not in category_to_index:\n",
    "        category_to_index[category] = len(category_to_index)\n",
    "    tag_to_category[tag] = category\n",
    "index_to_category = {index: category for category, index in category_to_index.items()}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Готовим данные к обучению**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def is_valid_tag(tag: str) -> bool:\n",
    "    \"\"\"Return True for arXiv tags listed in arxiv_topics.csv.\"\"\"\n",
    "    return tag in tag_to_category\n",
    "\n",
    "total_categories_count = 0\n",
    "total_tags_count = 0\n",
    "records = []\n",
    "for arxiv_record in tqdm(arxiv_data):\n",
    "    # 'tag' is a string holding a Python list literal (see arxiv_data[0]);\n",
    "    # ast.literal_eval parses it without executing code, unlike eval().\n",
    "    raw_tags = ast.literal_eval(arxiv_record['tag'])\n",
    "    tags = [current_tag['term'] for current_tag in raw_tags if is_valid_tag(current_tag['term'])]\n",
    "    assert len(tags) > 0, f\"no known tags for article {arxiv_record['id']}\"\n",
    "    record_category_indices = sorted({category_to_index[tag_to_category[tag]] for tag in tags})\n",
    "    total_categories_count += len(record_category_indices)\n",
    "    total_tags_count += len(tags)\n",
    "    records.append({\n",
    "        'title': arxiv_record['title'],\n",
    "        'summary': arxiv_record['summary'],\n",
    "        'title_and_summary': arxiv_record['title'] + ' $ ' + arxiv_record['summary'],\n",
    "        'tags': tags,\n",
    "        'categories_indices': record_category_indices,\n",
    "    })\n",
    "\n",
    "print(f'Среднее число категорий в одной статье: {total_categories_count / len(arxiv_data)}')\n",
    "print(f'Среднее число тегов в одной статье: {total_tags_count / len(arxiv_data)}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Как видим, перед нами задача многометочной (multi-label) классификации: у одной статьи может быть несколько меток.\n",
    "\n",
    "Тегов у одной статьи бывает много, это понятно, но и категорий тоже бывает несколько. То есть статья может быть посвящена, условно, и физике, и биологии одновременно.\n",
    "\n",
    "Можно было бы обучать модель определять теги: потенциально она сохранила бы больше информации, чем при обучении на категориях (которых гораздо меньше). Однако ниже модель обучается предсказывать именно категории (`categories_indices`), несмотря на суффикс `tags-classification` в названии опубликованной модели."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Соединяем title и summary, используя символ `$`: он редкий и при этом есть в словаре токенизатора, поэтому с ним не придётся дополнительно возиться.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "full_data_df = pd.DataFrame(records)\n",
    "print(len(full_data_df))\n",
    "full_data_df.head(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Как будет видно из распределения категорий ниже, Computer Science встречается очень часто, а, например, экономика — совсем редко. Значит, при обучении экономике логично давать больший вес.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_data = list(full_data_df['title_and_summary'])\n",
    "categories_indices = list(full_data_df['categories_indices'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train is 70%, val is 20%, test is 10%: hold out 10% for test first,\n",
    "# then 2/9 of the remaining 90% (= 20% of the total) for validation.\n",
    "X_train_val, X_test, y_train_val, y_test = train_test_split(text_data, categories_indices, test_size=0.1, random_state=SEED)\n",
    "X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=2/9, random_state=SEED)\n",
    "print(len(X_train), len(X_val), len(X_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Посмотрим на распределение категорий в тренировочной выборке"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "category_to_count = defaultdict(int)\n",
    "for row in y_train:\n",
    "    for category in row:\n",
    "        category_to_count[category] += 1\n",
    "category_to_count = dict(category_to_count)\n",
    "print(category_to_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AlbertTokenizer.from_pretrained(USED_MODEL)\n",
    "\n",
    "def tokenize_function(text):\n",
    "    \"\"\"Tokenize a string or list of strings, padding/truncating to the model's max length.\"\"\"\n",
    "    return tokenizer(text, padding=\"max_length\", truncation=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_encodings = tokenize_function(X_train)\n",
    "val_encodings = tokenize_function(X_val)\n",
    "test_encodings = tokenize_function(X_test)\n",
    "\n",
    "print(type(train_encodings))\n",
    "print(list(train_encodings.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_labels(y: List[List[int]]) -> List[List[float]]:\n",
    "    \"\"\"Convert per-article lists of category indices into multi-hot float vectors.\"\"\"\n",
    "    labels = np.zeros((len(y), len(category_to_index)))\n",
    "    for i, indices in enumerate(tqdm(y)):\n",
    "        labels[i, indices] = 1\n",
    "    return labels.tolist()\n",
    "\n",
    "labels_train = get_labels(y_train)\n",
    "labels_val = get_labels(y_val)\n",
    "labels_test = get_labels(y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_encodings['labels'] = labels_train\n",
    "val_encodings['labels'] = labels_val\n",
    "test_encodings['labels'] = labels_test"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Я использовал пример отсюда, чтобы понять, какой нужен формат данных: https://github.com/NielsRogge/Transformers-Tutorials/blob/master/BERT/Fine_tuning_BERT_(and_friends)_for_multi_label_text_classification.ipynb**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = Dataset.from_dict(train_encodings)\n",
    "val_dataset = Dataset.from_dict(val_encodings)\n",
    "test_dataset = Dataset.from_dict(test_encodings)\n",
    "\n",
    "train_dataset.set_format(\"torch\")\n",
    "val_dataset.set_format(\"torch\")\n",
    "test_dataset.set_format(\"torch\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed first so that the randomly initialised classification head is reproducible.\n",
    "set_seed(SEED)\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    USED_MODEL,\n",
    "    problem_type=\"multi_label_classification\",\n",
    "    num_labels=len(category_to_index),\n",
    "    id2label=index_to_category,\n",
    "    label2id=category_to_index\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 8\n",
    "metric_name = \"f1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "args = TrainingArguments(\n",
    "    output_dir=f'train-{USED_MODEL}',\n",
    "    eval_strategy=\"epoch\",  # `evaluation_strategy` is deprecated\n",
    "    save_strategy=\"epoch\",\n",
    "    learning_rate=2e-5,\n",
    "    per_device_train_batch_size=batch_size,\n",
    "    per_device_eval_batch_size=batch_size,\n",
    "    num_train_epochs=5,\n",
    "    weight_decay=0.01,\n",
    "    load_best_model_at_end=True,\n",
    "    metric_for_best_model=metric_name,\n",
    "    seed=SEED,\n",
    "    push_to_hub=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adapted from https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/\n",
    "def multi_label_metrics(predictions, labels, threshold=0.5):\n",
    "    \"\"\"Compute micro-F1, micro ROC-AUC and subset accuracy for multi-label logits.\n",
    "\n",
    "    predictions: raw logits, one row per article and one column per category.\n",
    "    labels: multi-hot targets of the same shape.\n",
    "    F1 and accuracy use predictions thresholded at `threshold`. ROC-AUC is computed\n",
    "    from the sigmoid probabilities: AUC on hard 0/1 predictions collapses the ROC\n",
    "    curve to a single operating point. `accuracy` is exact-match (subset) accuracy.\n",
    "    \"\"\"\n",
    "    probs = torch.sigmoid(torch.as_tensor(predictions, dtype=torch.float32)).numpy()\n",
    "    y_pred = (probs >= threshold).astype(int)\n",
    "    y_true = labels\n",
    "    return {\n",
    "        'f1': f1_score(y_true=y_true, y_pred=y_pred, average='micro'),\n",
    "        'roc_auc': roc_auc_score(y_true, probs, average='micro'),\n",
    "        'accuracy': accuracy_score(y_true, y_pred),\n",
    "    }\n",
    "\n",
    "def compute_metrics(p: EvalPrediction):\n",
    "    \"\"\"Trainer hook: unpack the logits from EvalPrediction and compute the metrics.\"\"\"\n",
    "    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions\n",
    "    return multi_label_metrics(predictions=preds, labels=p.label_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-category positive weight N / (count_c * K): N training articles, K categories.\n",
    "# Rare categories get weight > 1, frequent ones (Computer Science) get weight < 1.\n",
    "# Note: the usual BCEWithLogitsLoss recipe is #negatives / #positives instead.\n",
    "# max(..., 1) guards against a category that is absent from the training split.\n",
    "num_categories = len(category_to_index)\n",
    "pos_weight = torch.tensor([\n",
    "    len(y_train) / max(category_to_count.get(i, 0), 1) / num_categories\n",
    "    for i in range(num_categories)\n",
    "], dtype=torch.float32)\n",
    "\n",
    "# Custom-loss Trainer pattern adapted from\n",
    "# https://medium.com/deeplearningmadeeasy/how-to-use-a-custom-loss-with-hugging-face-fc9a1f91b39b\n",
    "class CustomTrainer(Trainer):\n",
    "    \"\"\"Trainer that optimises BCEWithLogitsLoss with an optional per-category `pos_weight`.\"\"\"\n",
    "\n",
    "    def __init__(self, *args, pos_weight=None, **kwargs):\n",
    "        super().__init__(*args, **kwargs)\n",
    "        self.pos_weight = pos_weight\n",
    "\n",
    "    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):\n",
    "        labels = inputs.pop(\"labels\")\n",
    "        outputs = model(**inputs)\n",
    "        logits = outputs.logits\n",
    "        # Move the weights lazily so they always sit on the same device as the logits.\n",
    "        pos_weight = None if self.pos_weight is None else self.pos_weight.to(logits.device)\n",
    "        loss = BCEWithLogitsLoss(pos_weight=pos_weight)(logits, labels.to(logits.dtype))\n",
    "        return (loss, outputs) if return_outputs else loss\n",
    "\n",
    "pos_weight"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = CustomTrainer(\n",
    "    model,\n",
    "    args,\n",
    "    train_dataset=train_dataset,\n",
    "    eval_dataset=val_dataset,\n",
    "    processing_class=tokenizer,  # replaces the deprecated `tokenizer=` argument\n",
    "    compute_metrics=compute_metrics,\n",
    "    pos_weight=pos_weight,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.evaluate(eval_dataset=val_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.evaluate(eval_dataset=test_dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Результаты предыдущего запуска** (albert-base-v2, 5 эпох, около 64 минут на GPU). Выводы ячеек очищены, так как код метрик изменён.\n",
    "\n",
    "| Эпоха | Training loss | Validation loss | F1 (micro) | ROC-AUC (старый расчёт) | Accuracy |\n",
    "|---|---|---|---|---|---|\n",
    "| 1 | 0.4991 | 0.3870 | 0.8263 | 0.8668 | 0.6820 |\n",
    "| 2 | 0.2953 | 0.3861 | 0.8375 | 0.8713 | 0.6949 |\n",
    "| 3 | 0.3839 | 0.3541 | 0.8485 | 0.8871 | 0.7051 |\n",
    "| 4 | 0.1675 | 0.3753 | 0.8509 | 0.8888 | 0.7076 |\n",
    "| 5 | 0.2821 | 0.4093 | 0.8578 | 0.8984 | 0.7129 |\n",
    "\n",
    "На тесте (модель 5-й эпохи, лучшая по F1 на валидации): F1 = 0.8537, ROC-AUC (старый расчёт) = 0.8971, accuracy = 0.7088, loss = 0.5080.\n",
    "\n",
    "В том запуске ROC-AUC считался по бинаризованным предсказаниям, поэтому эти значения не сравнимы с ROC-AUC, который считает текущий код (по вероятностям)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Исходная задача звучала как «хотим увидеть топ-95% тематик, отсортированных по убыванию вероятности», то есть минимальный набор самых вероятных тематик, суммарная вероятность которых не меньше 95%. Под тематиками имеются в виду категории (физика, биология и так далее).\n",
    "\n",
    "Будем делать следующее:\n",
    "- наша модель выдаёт логиты категорий;\n",
    "- посчитаем по ним вероятности категорий, нормировав их так, чтобы сумма была равна 1 (хотя на самом деле категорий у статьи может быть несколько);\n",
    "- выведем требуемые топ-95% тематик."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the checkpoint selected by load_best_model_at_end (highest validation F1), i.e.\n",
    "# the model whose validation/test metrics are reported above, instead of a hard-coded\n",
    "# checkpoint directory. To publish the lowest-validation-loss checkpoint instead\n",
    "# (epoch 3 in the recorded run), set metric_for_best_model=\"loss\" and re-run training.\n",
    "best_checkpoint = trainer.state.best_model_checkpoint\n",
    "print(best_checkpoint)\n",
    "\n",
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    best_checkpoint,\n",
    "    problem_type=\"multi_label_classification\",\n",
    "    num_labels=len(category_to_index),\n",
    "    id2label=index_to_category,\n",
    "    label2id=category_to_index\n",
    ").to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoke test: raw logits for a toy 'title $ abstract' input.\n",
    "with torch.no_grad():\n",
    "    toy_inputs = tokenizer(\n",
    "        'Maths is cool $ In our article we prove that maths is the coolest subject at school',\n",
    "        padding=\"max_length\", truncation=True, return_tensors='pt',\n",
    "    ).to(model.device)\n",
    "    toy_logits = model(**toy_inputs).logits\n",
    "toy_logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def get_category_probs_dict(model, title: str, summary: str) -> Dict[str, float]:\n",
    "    \"\"\"Return one probability per category for an article, normalised to sum to 1.\n",
    "\n",
    "    The model is multi-label (an independent sigmoid per category); the sigmoid\n",
    "    outputs are rescaled to sum to 1 so that get_most_probable_keys can cut them\n",
    "    off at a cumulative threshold.\n",
    "    \"\"\"\n",
    "    text = f'{title} $ {summary}'\n",
    "    inputs = tokenizer(text, padding=\"max_length\", truncation=True, return_tensors='pt').to(model.device)\n",
    "    category_logits = model(**inputs).logits[0]\n",
    "    category_probs = torch.sigmoid(category_logits).cpu().numpy()\n",
    "    category_probs /= category_probs.sum()\n",
    "    return {index_to_category[index]: float(prob) for index, prob in enumerate(category_probs)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_most_probable_keys(probs_dict: Dict[str, float], target_probability: float, print_probabilities: bool) -> List[str]:\n",
    "    \"\"\"Return keys in decreasing order of probability until their cumulative\n",
    "    probability reaches `target_probability`.\n",
    "\n",
    "    The most probable key is always included (for a non-empty dict); if all\n",
    "    probabilities sum to less than `target_probability`, every key is returned.\n",
    "    With print_probabilities=True each entry is formatted as 'key (probability)'.\n",
    "    \"\"\"\n",
    "    answer = []\n",
    "    cumulative_probability = 0.0\n",
    "    for key, probability in sorted(probs_dict.items(), key=lambda item: item[1], reverse=True):\n",
    "        if answer and cumulative_probability >= target_probability:\n",
    "            break\n",
    "        answer.append(f'{key} ({probability})' if print_probabilities else key)\n",
    "        cumulative_probability += probability\n",
    "    return answer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Сохраняем модель, чтобы потом можно было её использовать в Hugging Face Space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE: despite the `tags-classification` suffix, this model predicts categories.\n",
    "model.push_to_hub(f\"bumchik2/train-{USED_MODEL}-tags-classification\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Теперь я могу загружать свою модель оттуда"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForSequenceClassification.from_pretrained(\n",
    "    f\"bumchik2/train-{USED_MODEL}-tags-classification\",\n",
    "    problem_type=\"multi_label_classification\",\n",
    "    num_labels=len(category_to_index),\n",
    "    id2label=index_to_category,\n",
    "    label2id=category_to_index\n",
    ").to(DEVICE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Протестируем на нескольких реальных примерах:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# expected answer: Quantitative Biology\n",
    "get_most_probable_keys(\n",
    "    probs_dict=get_category_probs_dict(\n",
    "        model=model,\n",
    "        title='Simulating cell populations with explicit cell cycle length -- implications to cell cycle dependent tumour therapy',\n",
    "        summary='In this study, we present a stochastic simulation model designed to explicitly incorporate cell cycle length, overcoming limitations associated with classical compartmental models. Our approach employs a delay mechanism to represent the cell cycle, allowing the use of arbitrary distributions for cell cycle lengths. We demonstrate the feasibility of our model by fitting it to experimental data from melanoma cell lines previously studied by Vittadello et al. Notably, our model successfully replicates experimentally observed synchronization phenomena that multi-stage models could not adequately explain. By using a gamma distribution to model cell cycle lengths, we achieved excellent agreement between our simulations and empirical data, while significantly reducing computational complexity and parameter estimation challenges inherent in multi-stage approaches. Our results highlight the importance of explicitly incorporating cell cycle lengths in modeling cell populations, with potential implications for optimizing cell cycle-dependent tumor therapies.'\n",
    "    ),\n",
    "    target_probability=0.95,\n",
    "    print_probabilities=True\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# expected answer: Physics\n",
    "get_most_probable_keys(\n",
    "    probs_dict=get_category_probs_dict(\n",
    "        model=model,\n",
    "        title='Performance Improvement of LTS Undulators for Synchrotron Light Sources',\n",
    "        summary='The joint expertise of ANL and FNAL has led to the production of Nb3Sn undulator magnets in operation in the ANL Advanced Photon Source (APS). These magnets showed performance reproducibility close to the short sample limit, and a design field increase of 20% at 820A. However, the long training did not allow obtaining the expected 50% increase of the on-axis magnetic field with respect to the ~1 T produced at 450 A current in the ANL NbTi undulator. To address this, 10-pole long undulator prototypes were fabricated, and CTD-101K was replaced as impregnation material with TELENE, an organic olefin-based thermosetting dicyclopentadiene resin produced by RIMTEC Corporation, Japan. Training and magnet retraining after a thermal cycle were nearly eliminated, with only a couple of quenches needed before reaching short sample limit at over 1,100 A. TELENE will enable operation of Nb3Sn undulators much closer to their short sample limit, expanding the energy range and brightness intensity of light sources. TELENE is Co-60 gamma radiation resistant up to 7-8 MGy, and therefore already applicable to impregnate planar, helical and universal devices operating in lower radiation environments than high energy colliders.'\n",
    "    ),\n",
    "    target_probability=0.95,\n",
    "    print_probabilities=True\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "В предыдущем запуске (модель `checkpoint-10764`, то есть конец 3-й эпохи) были получены такие ответы:\n",
    "- первая статья: Quantitative Biology (0.428), Statistics (0.343), Computer Science (0.142), Physics (0.035), Mathematics (0.029) — верная категория на первом месте;\n",
    "- вторая статья: Computer Science (0.482), Physics (0.414), Statistics (0.030), Electrical Engineering and Systems Science (0.028) — верная категория (Physics) только на втором месте."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Tricks",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}