{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Predicting Regression - Multicomponent" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/predicting_regression_multicomponent.ipynb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install chemprop from GitHub if running in Google Colab\n", "import os\n", "\n", "if os.getenv(\"COLAB_RELEASE_TAG\"):\n", "    try:\n", "        import chemprop\n", "    except ImportError:\n", "        !git clone https://github.com/chemprop/chemprop.git\n", "        %cd chemprop\n", "        !pip install .\n", "        %cd examples" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Import packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "from lightning import pytorch as pl\n", "\n", "from chemprop import data, featurizers\n", "from chemprop.models import multi" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Change model input here" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "chemprop_dir = Path.cwd().parent\n", "checkpoint_path = chemprop_dir / \"tests\" / \"data\" / \"example_model_v2_regression_mol+mol.ckpt\"  # path to the checkpoint file\n", "# If the checkpoint file is generated using the training notebook, it will be in the `checkpoints` folder with name similar to `checkpoints/epoch=19-step=180.ckpt`."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load model" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MulticomponentMPNN(\n", " (message_passing): MulticomponentMessagePassing(\n", " (blocks): ModuleList(\n", " (0-1): 2 x BondMessagePassing(\n", " (W_i): Linear(in_features=86, out_features=300, bias=False)\n", " (W_h): Linear(in_features=300, out_features=300, bias=False)\n", " (W_o): Linear(in_features=372, out_features=300, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (tau): ReLU()\n", " (V_d_transform): Identity()\n", " (graph_transform): GraphTransform(\n", " (V_transform): Identity()\n", " (E_transform): Identity()\n", " )\n", " )\n", " )\n", " )\n", " (agg): MeanAggregation()\n", " (bn): BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (predictor): RegressionFFN(\n", " (ffn): MLP(\n", " (0): Sequential(\n", " (0): Linear(in_features=600, out_features=300, bias=True)\n", " )\n", " (1): Sequential(\n", " (0): ReLU()\n", " (1): Dropout(p=0.0, inplace=False)\n", " (2): Linear(in_features=300, out_features=1, bias=True)\n", " )\n", " )\n", " (criterion): MSE(task_weights=[[1.0]])\n", " (output_transform): UnscaleTransform()\n", " )\n", " (X_d_transform): Identity()\n", " (metrics): ModuleList(\n", " (0-1): 2 x MSE(task_weights=[[1.0]])\n", " )\n", ")" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mcmpnn = multi.MulticomponentMPNN.load_from_checkpoint(checkpoint_path)\n", "mcmpnn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Change predict input here" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "chemprop_dir = Path.cwd().parent\n", "test_path = chemprop_dir / \"tests\" / \"data\" / \"regression\" / \"mol+mol\" / \"mol+mol.csv\"  # path to your .csv file containing SMILES strings to make predictions for\n",
"smiles_columns = ['smiles', 'solvent']  # names of the columns containing SMILES strings, one column per component" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load test smiles" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
smilessolventpeakwavs_max
0CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C...ClCCl642.0
1C(=C/c1cnccn1)\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c...ClCCl420.0
2CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]...O544.0
3c1ccc2[nH]ccc2c1O290.0
4CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c...ClC(Cl)Cl736.0
............
95COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)...C1CCOC1359.0
96COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc...C1CCCCC1386.0
97CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=OCCO425.0
98Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)...c1ccccc1324.0
99Cc1ccc(C(=O)c2c(C)c3ccc4cccc5c6cccc7ccc2c(c76)...ClCCl391.0
\n", "

100 rows × 3 columns

\n", "
" ], "text/plain": [ " smiles solvent peakwavs_max\n", "0 CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C... ClCCl 642.0\n", "1 C(=C/c1cnccn1)\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c... ClCCl 420.0\n", "2 CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]... O 544.0\n", "3 c1ccc2[nH]ccc2c1 O 290.0\n", "4 CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c... ClC(Cl)Cl 736.0\n", ".. ... ... ...\n", "95 COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)... C1CCOC1 359.0\n", "96 COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc... C1CCCCC1 386.0\n", "97 CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=O CCO 425.0\n", "98 Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)... c1ccccc1 324.0\n", "99 Cc1ccc(C(=O)c2c(C)c3ccc4cccc5c6cccc7ccc2c(c76)... ClCCl 391.0\n", "\n", "[100 rows x 3 columns]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_test = pd.read_csv(test_path)\n", "df_test" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get smiles" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([['CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2CCCC)C(=O)N(CCCC)C1=S',\n", " 'ClCCl'],\n", " ['C(=C/c1cnccn1)\\\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3cnccn3)cc2)cc1',\n", " 'ClCCl'],\n", " ['CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+](C)C)cc-3oc2c1',\n", " 'O'],\n", " ['c1ccc2[nH]ccc2c1', 'O'],\n", " ['CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5ccccc5c4C3(C)C)CCCC1=C2c1ccccc1C(=O)O',\n", " 'ClC(Cl)Cl']], dtype=object)" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "smiss = df_test[smiles_columns].values\n", "smiss[:5]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get molecule datapoints" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "n_components = len(smiles_columns)\n", "# build one list of MoleculeDatapoints per SMILES column (column i of smiss is component i)\n",
"test_datapointss = [[data.MoleculeDatapoint.from_smi(smi) for smi in smiss[:, i]] for i in range(n_components)]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get molecule datasets" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()\n", "test_dsets = [data.MoleculeDataset(test_datapoints, featurizer) for test_datapoints in test_datapointss]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Get multicomponent dataset and data loader" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "test_mcdset = data.MulticomponentDataset(test_dsets)\n", "test_loader = data.build_dataloader(test_mcdset, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Set up trainer" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (mps), used: True\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n", "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Predicting DataLoader 0: 0%| | 0/2 [00:00\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
smilessolventpeakwavs_maxpred
0CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C...ClCCl642.0454.898621
1C(=C/c1cnccn1)\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c...ClCCl420.0453.561584
2CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]...O544.0448.694977
3c1ccc2[nH]ccc2c1O290.0448.159760
4CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c...ClC(Cl)Cl736.0456.897003
...............
95COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)...C1CCOC1359.0454.548584
96COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc...C1CCCCC1386.0455.287140
97CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=OCCO425.0453.560364
98Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)...c1ccccc1324.0454.656891
99Cc1ccc(C(=O)c2c(C)c3ccc4cccc5c6cccc7ccc2c(c76)...ClCCl391.0453.118774
\n", "

100 rows × 4 columns

\n", "" ], "text/plain": [ " smiles solvent \\\n", "0 CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C... ClCCl \n", "1 C(=C/c1cnccn1)\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c... ClCCl \n", "2 CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]... O \n", "3 c1ccc2[nH]ccc2c1 O \n", "4 CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c... ClC(Cl)Cl \n", ".. ... ... \n", "95 COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)... C1CCOC1 \n", "96 COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc... C1CCCCC1 \n", "97 CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=O CCO \n", "98 Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)... c1ccccc1 \n", "99 Cc1ccc(C(=O)c2c(C)c3ccc4cccc5c6cccc7ccc2c(c76)... ClCCl \n", "\n", " peakwavs_max pred \n", "0 642.0 454.898621 \n", "1 420.0 453.561584 \n", "2 544.0 448.694977 \n", "3 290.0 448.159760 \n", "4 736.0 456.897003 \n", ".. ... ... \n", "95 359.0 454.548584 \n", "96 386.0 455.287140 \n", "97 425.0 453.560364 \n", "98 324.0 454.656891 \n", "99 391.0 453.118774 \n", "\n", "[100 rows x 4 columns]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# test_preds is assumed to be the per-batch prediction list from the predict step; stack it into one array\n", "# under a new name so test_preds is not overwritten and this cell can be re-run safely\n", "preds = np.concatenate(test_preds, axis=0)\n", "df_test['pred'] = preds\n", "df_test" ] } ], "metadata": { "kernelspec": { "display_name": "chemprop", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, "nbformat": 4, "nbformat_minor": 4 }