{ "cells": [
 { "cell_type": "markdown", "metadata": {}, "source": [ "# Transfer Learning / Pretraining\n", "Transfer learning (or pretraining) leverages knowledge from a pre-trained model on a related task to enhance performance on a new task. In Chemprop, we can use pre-trained model checkpoints to initialize a new model and freeze components of the new model during training, as demonstrated in this notebook." ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/transfer_learning.ipynb)" ] },
 { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install chemprop from GitHub if running in Google Colab\n", "import os\n", "\n", "if os.getenv(\"COLAB_RELEASE_TAG\"):\n", "    try:\n", "        import chemprop\n", "    except ImportError:\n", "        !git clone https://github.com/chemprop/chemprop.git\n", "        %cd chemprop\n", "        # %pip (not !pip) installs into the environment of the running kernel\n", "        %pip install .\n", "        %cd examples" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "# Import packages" ] },
 { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "import pandas as pd\n", "from lightning import pytorch as pl\n", "from sklearn.preprocessing import StandardScaler\n", "import torch\n", "\n", "from chemprop import data, featurizers, models" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "# Change data inputs here" ] },
 { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "chemprop_dir = Path.cwd().parent\n", "input_path = chemprop_dir / \"tests\" / \"data\" / \"regression\" / \"mol\" / \"mol.csv\" # path to your data .csv file\n", "num_workers = 0 # number of workers for dataloader. 0 means using main process for data loading\n", "smiles_column = 'smiles' # name of the column containing SMILES strings\n", "target_columns = ['lipo'] # list of names of the columns containing targets" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Load data" ] },
 { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | smiles | \n", "lipo | \n", "
---|---|---|
0 | \n", "Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14 | \n", "3.54 | \n", "
1 | \n", "COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... | \n", "-1.18 | \n", "
2 | \n", "COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl | \n", "3.69 | \n", "
3 | \n", "OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C... | \n", "3.37 | \n", "
4 | \n", "Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N... | \n", "3.10 | \n", "
... | \n", "... | \n", "... | \n", "
95 | \n", "CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C... | \n", "2.20 | \n", "
96 | \n", "CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)... | \n", "2.04 | \n", "
97 | \n", "CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)... | \n", "4.49 | \n", "
98 | \n", "COc1ccc(Cc2c(N)n[nH]c2N)cc1 | \n", "0.20 | \n", "
99 | \n", "CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(... | \n", "2.00 | \n", "
100 rows × 2 columns
\n", "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃ Test metric ┃ DataLoader 0 ┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│ test/mse │ 0.9625480771064758 │\n", "└───────────────────────────┴───────────────────────────┘\n", "\n" ], "text/plain": [ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", "│\u001b[36m \u001b[0m\u001b[36m test/mse \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.9625480771064758 \u001b[0m\u001b[35m \u001b[0m│\n", "└───────────────────────────┴───────────────────────────┘\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "results = trainer.test(mpnn, test_loader)" ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "# Transfer learning with multicomponent models\n", "Multicomponent MPNN models have an individual MPNN block for each molecule they parse in one input. These MPNN blocks can be frozen independently for transfer learning." ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Change data inputs here" ] },
 { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "chemprop_dir = Path.cwd().parent\n", "checkpoint_path = chemprop_dir / \"tests\" / \"data\" / \"example_model_v2_regression_mol+mol.ckpt\" # path to the checkpoint file." ] },
 { "cell_type": "markdown", "metadata": {}, "source": [ "## Change checkpoint model inputs here" ] },
 { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MulticomponentMPNN(\n", " (message_passing): MulticomponentMessagePassing(\n", " (blocks): ModuleList(\n", " (0-1): 2 x BondMessagePassing(\n", " (W_i): Linear(in_features=86, out_features=300, bias=False)\n", " (W_h): Linear(in_features=300, out_features=300, bias=False)\n", " (W_o): Linear(in_features=372, out_features=300, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (tau): ReLU()\n", " (V_d_transform): Identity()\n", " (graph_transform): GraphTransform(\n", " (V_transform): Identity()\n", " (E_transform): Identity()\n", " )\n", " )\n", " )\n", " )\n", " (agg): MeanAggregation()\n", " (bn): BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (predictor): RegressionFFN(\n", " (ffn): MLP(\n", " (0): Sequential(\n", " (0): Linear(in_features=600, out_features=300, bias=True)\n", " )\n", " (1): Sequential(\n", " (0): ReLU()\n", " (1): Dropout(p=0.0, inplace=False)\n", " (2): Linear(in_features=300, out_features=1, bias=True)\n", " )\n", " )\n", " (criterion): MSE(task_weights=[[1.0]])\n", " (output_transform): UnscaleTransform()\n", " )\n", " (X_d_transform): Identity()\n", " (metrics): ModuleList(\n", " (0-1): 2 x MSE(task_weights=[[1.0]])\n", " )\n", ")" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mpnn_cls = models.MulticomponentMPNN\n", "mcmpnn = mpnn_cls.load_from_checkpoint(checkpoint_path)\n", "mcmpnn" ] },
 { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "blocks_to_freeze = [0, 1] # a list of indices of the individual MPNN blocks to freeze before training."
] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Reload so this cell starts from the checkpoint weights every time it is run\n", "mcmpnn = mpnn_cls.load_from_checkpoint(checkpoint_path)\n", "\n", "# nn.Module.requires_grad_ already recurses over all parameters of the module,\n", "# so one call per block freezes the entire block.\n", "for i in blocks_to_freeze:\n", "    mp_block = mcmpnn.message_passing.blocks[i]\n", "    mp_block.requires_grad_(False)\n", "    mp_block.eval()  # switch the frozen block to inference behavior (e.g. dropout)\n", "\n", "# Also freeze the batch-norm layer `bn`; in eval mode it uses its stored running\n", "# statistics instead of updating them from training batches.\n", "# NOTE: confirm that your Lightning version keeps submodules in eval mode during fit.\n", "mcmpnn.bn.requires_grad_(False)\n", "mcmpnn.bn.eval()" ] } ], "metadata": { "kernelspec": { "display_name": "chemprop", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, "nbformat": 4, "nbformat_minor": 4 }