{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Transfer Learning / Pretraining\n", "Transfer learning (or pretraining) leverages knowledge from a pre-trained model on a related task to enhance performance on a new task. In Chemprop, we can use pre-trained model checkpoints to initialize a new model and freeze components of the new model during training, as demonstrated in this notebook." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/transfer_learning.ipynb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install chemprop from GitHub if running in Google Colab\n", "import os\n", "\n", "if os.getenv(\"COLAB_RELEASE_TAG\"):\n", " try:\n", " import chemprop\n", " except ImportError:\n", " !git clone https://github.com/chemprop/chemprop.git\n", " %cd chemprop\n", " !pip install .\n", " %cd examples" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Import packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "from pathlib import Path\n", "\n", "from lightning import pytorch as pl\n", "from sklearn.preprocessing import StandardScaler\n", "import torch\n", "\n", "from chemprop import data, featurizers, models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Change data inputs here" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "chemprop_dir = Path.cwd().parent\n", "input_path = chemprop_dir / \"tests\" / \"data\" / \"regression\" / \"mol\" / \"mol.csv\" # path to your data .csv file\n", "num_workers = 0 # number of workers for dataloader. 
0 means using main process for data loading\n", "smiles_column = 'smiles' # name of the column containing SMILES strings\n", "target_columns = ['lipo'] # list of names of the columns containing targets" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load data" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
smileslipo
0Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc143.54
1COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...-1.18
2COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl3.69
3OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...3.37
4Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...3.10
.........
95CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...2.20
96CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...2.04
97CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...4.49
98COc1ccc(Cc2c(N)n[nH]c2N)cc10.20
99CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(...2.00
\n", "

100 rows × 2 columns

\n", "
" ], "text/plain": [ " smiles lipo\n", "0 Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14 3.54\n", "1 COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... -1.18\n", "2 COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl 3.69\n", "3 OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C... 3.37\n", "4 Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N... 3.10\n", ".. ... ...\n", "95 CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C... 2.20\n", "96 CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)... 2.04\n", "97 CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)... 4.49\n", "98 COc1ccc(Cc2c(N)n[nH]c2N)cc1 0.20\n", "99 CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(... 2.00\n", "\n", "[100 rows x 2 columns]" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_input = pd.read_csv(input_path)\n", "df_input" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get SMILES and targets" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "smis = df_input.loc[:, smiles_column].values\n", "ys = df_input.loc[:, target_columns].values" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14',\n", " 'COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)CCc3ccccc23',\n", " 'COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl',\n", " 'OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(Cl)sc4[nH]3',\n", " 'Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)NCC#N)c1'],\n", " dtype=object)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "smis[:5] # show first 5 SMILES strings" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([[ 3.54],\n", " [-1.18],\n", " [ 3.69],\n", " [ 3.37],\n", " [ 3.1 ]])" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ys[:5] # show first 5 targets" ] }, { "cell_type": "markdown", "metadata": 
{}, "source": [ "## Get molecule datapoints" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Perform data splitting for training, validation, and testing" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['SCAFFOLD_BALANCED',\n", " 'RANDOM_WITH_REPEATED_SMILES',\n", " 'RANDOM',\n", " 'KENNARD_STONE',\n", " 'KMEANS']" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# available split types\n", "list(data.SplitType.keys())" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "mols = [d.mol for d in all_data] # RDKit Mol objects are used for structure-based splits\n", "train_indices, val_indices, test_indices = data.make_split_indices(mols, \"random\", (0.8, 0.1, 0.1))\n", "train_data, val_data, test_data = data.split_data_by_indices(\n", " all_data, train_indices, val_indices, test_indices\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Change checkpoint model inputs here\n", "Both message-passing neural networks (MPNNs) and multi-component MPNNs can have their weights initialized from a checkpoint file." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "chemprop_dir = Path.cwd().parent\n", "checkpoint_path = chemprop_dir / \"tests\" / \"data\" / \"example_model_v2_regression_mol.ckpt\" # path to the checkpoint file.\n", "# If the checkpoint file is generated using the training notebook, it will be in the `checkpoints` folder with a name similar to `checkpoints/epoch=19-step=180.ckpt`." 
] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "mpnn_cls = models.MPNN" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MPNN(\n", " (message_passing): BondMessagePassing(\n", " (W_i): Linear(in_features=86, out_features=300, bias=False)\n", " (W_h): Linear(in_features=300, out_features=300, bias=False)\n", " (W_o): Linear(in_features=372, out_features=300, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (tau): ReLU()\n", " (V_d_transform): Identity()\n", " (graph_transform): GraphTransform(\n", " (V_transform): Identity()\n", " (E_transform): Identity()\n", " )\n", " )\n", " (agg): MeanAggregation()\n", " (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (predictor): RegressionFFN(\n", " (ffn): MLP(\n", " (0): Sequential(\n", " (0): Linear(in_features=300, out_features=300, bias=True)\n", " )\n", " (1): Sequential(\n", " (0): ReLU()\n", " (1): Dropout(p=0.0, inplace=False)\n", " (2): Linear(in_features=300, out_features=1, bias=True)\n", " )\n", " )\n", " (criterion): MSE(task_weights=[[1.0]])\n", " (output_transform): UnscaleTransform()\n", " )\n", " (X_d_transform): Identity()\n", " (metrics): ModuleList(\n", " (0-1): 2 x MSE(task_weights=[[1.0]])\n", " )\n", ")" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mpnn = mpnn_cls.load_from_file(checkpoint_path)\n", "mpnn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Scale fine-tuning data with the model's target scaler\n", "\n", "If the pre-trained model was a regression model, it probably was trained on a scaled dataset. The scaler is saved as part of the model and used during prediction. For further training, we need to scale the fine-tuning data with the same target scaler." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "pretraining_scaler = StandardScaler()\n", "pretraining_scaler.mean_ = mpnn.predictor.output_transform.mean.numpy()\n", "pretraining_scaler.scale_ = mpnn.predictor.output_transform.scale.numpy()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get MoleculeDataset" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()\n", "\n", "train_dset = data.MoleculeDataset(train_data[0], featurizer)\n", "train_dset.normalize_targets(pretraining_scaler)\n", "\n", "val_dset = data.MoleculeDataset(val_data[0], featurizer)\n", "val_dset.normalize_targets(pretraining_scaler)\n", "\n", "test_dset = data.MoleculeDataset(test_data[0], featurizer)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get DataLoader" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "train_loader = data.build_dataloader(train_dset, num_workers=num_workers)\n", "val_loader = data.build_dataloader(val_dset, num_workers=num_workers, shuffle=False)\n", "test_loader = data.build_dataloader(test_dset, num_workers=num_workers, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Freezing MPNN and FFN layers\n", "Certain layers of a pre-trained model can be kept unchanged during further training on a new task." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Freezing the MPNN" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mpnn.message_passing.apply(lambda module: module.requires_grad_(False))\n", "mpnn.message_passing.eval()\n", "mpnn.bn.apply(lambda module: module.requires_grad_(False))\n", "mpnn.bn.eval() # Set batch norm layers to eval mode to freeze running mean and running var." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Freezing FFN layers" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "frzn_ffn_layers = 1 # the number of consecutive FFN layers to freeze." ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "for idx in range(frzn_ffn_layers):\n", " mpnn.predictor.ffn[idx].requires_grad_(False)\n", " mpnn.predictor.ffn[idx + 1].eval()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Set up trainer" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (mps), used: True\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n" ] } ], "source": [ "trainer = pl.Trainer(\n", " logger=False,\n", " enable_checkpointing=True, # Use `True` if you want to save model checkpoints. 
The checkpoints will be saved in the `checkpoints` folder.\n", " enable_progress_bar=True,\n", " accelerator=\"auto\",\n", " devices=1,\n", " max_epochs=20, # number of epochs to train for\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Start training" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/brianli/Documents/chemprop/examples/checkpoints exists and is not empty.\n", "Loading `train_dataloader` to estimate number of stepping batches.\n", "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n", "\n", " | Name | Type | Params | Mode \n", "---------------------------------------------------------------\n", "0 | message_passing | BondMessagePassing | 227 K | eval \n", "1 | agg | MeanAggregation | 0 | train\n", "2 | bn | BatchNorm1d | 600 | eval \n", "3 | predictor | RegressionFFN | 90.6 K | train\n", "4 | X_d_transform | Identity | 0 | train\n", "5 | metrics | ModuleList | 0 | train\n", "---------------------------------------------------------------\n", "301 Trainable params\n", "318 K Non-trainable params\n", "318 K Total params\n", "1.276 Total estimated model params size (MB)\n", "11 Modules in train mode\n", "15 Modules in eval mode\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Sanity Checking DataLoader 0: 0%| | 0/1 [00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃ Test metric DataLoader 0 ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│ test/mse 0.9625480771064758 │\n", 
"└───────────────────────────┴───────────────────────────┘\n", "\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│\u001b[36m \u001b[0m\u001b[36m test/mse \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9625480771064758 \u001b[0m\u001b[35m \u001b[0m│\n", "└───────────────────────────┴───────────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "results = trainer.test(mpnn, test_loader)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Transfer learning with multicomponent models\n", "Multi-component MPNN models have an individual MPNN block for each molecule they parse in one input. These MPNN modules can be independently frozen for transfer learning." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Change data inputs here" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "chemprop_dir = Path.cwd().parent\n", "checkpoint_path = chemprop_dir / \"tests\" / \"data\" / \"example_model_v2_regression_mol+mol.ckpt\" # path to the checkpoint file. 
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Change checkpoint model inputs here" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MulticomponentMPNN(\n", " (message_passing): MulticomponentMessagePassing(\n", " (blocks): ModuleList(\n", " (0-1): 2 x BondMessagePassing(\n", " (W_i): Linear(in_features=86, out_features=300, bias=False)\n", " (W_h): Linear(in_features=300, out_features=300, bias=False)\n", " (W_o): Linear(in_features=372, out_features=300, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (tau): ReLU()\n", " (V_d_transform): Identity()\n", " (graph_transform): GraphTransform(\n", " (V_transform): Identity()\n", " (E_transform): Identity()\n", " )\n", " )\n", " )\n", " )\n", " (agg): MeanAggregation()\n", " (bn): BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (predictor): RegressionFFN(\n", " (ffn): MLP(\n", " (0): Sequential(\n", " (0): Linear(in_features=600, out_features=300, bias=True)\n", " )\n", " (1): Sequential(\n", " (0): ReLU()\n", " (1): Dropout(p=0.0, inplace=False)\n", " (2): Linear(in_features=300, out_features=1, bias=True)\n", " )\n", " )\n", " (criterion): MSE(task_weights=[[1.0]])\n", " (output_transform): UnscaleTransform()\n", " )\n", " (X_d_transform): Identity()\n", " (metrics): ModuleList(\n", " (0-1): 2 x MSE(task_weights=[[1.0]])\n", " )\n", ")" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mpnn_cls = models.MulticomponentMPNN\n", "mcmpnn = mpnn_cls.load_from_checkpoint(checkpoint_path)\n", "mcmpnn" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "blocks_to_freeze = [0, 1] # a list of indices of the individual MPNN blocks to freeze before training." 
] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mcmpnn = mpnn_cls.load_from_checkpoint(checkpoint_path)\n", "for i in blocks_to_freeze:\n", " mp_block = mcmpnn.message_passing.blocks[i]\n", " mp_block.apply(lambda module: module.requires_grad_(False))\n", " mp_block.eval()\n", "mcmpnn.bn.apply(lambda module: module.requires_grad_(False))\n", "mcmpnn.bn.eval()" ] } ], "metadata": { "kernelspec": { "display_name": "chemprop", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, "nbformat": 4, "nbformat_minor": 4 }