{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Predicting Regression - Multicomponent" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/predicting_regression_multicomponent.ipynb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install chemprop from GitHub if running in Google Colab\n", "import os\n", "\n", "if os.getenv(\"COLAB_RELEASE_TAG\"):\n", " try:\n", " import chemprop\n", " except ImportError:\n", " !git clone https://github.com/chemprop/chemprop.git\n", " %cd chemprop\n", " !pip install .\n", " %cd examples" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Import packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import torch\n", "from lightning import pytorch as pl\n", "from pathlib import Path\n", "\n", "from chemprop import data, featurizers\n", "from chemprop.models import multi" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Change model input here" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "chemprop_dir = Path.cwd().parent\n", "checkpoint_path = chemprop_dir / \"tests\" / \"data\" / \"example_model_v2_regression_mol+mol.ckpt\" # path to the checkpoint file. \n", "# If the checkpoint file is generated using the training notebook, it will be in the `checkpoints` folder with name similar to `checkpoints/epoch=19-step=180.ckpt`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load model" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MulticomponentMPNN(\n", " (message_passing): MulticomponentMessagePassing(\n", " (blocks): ModuleList(\n", " (0-1): 2 x BondMessagePassing(\n", " (W_i): Linear(in_features=86, out_features=300, bias=False)\n", " (W_h): Linear(in_features=300, out_features=300, bias=False)\n", " (W_o): Linear(in_features=372, out_features=300, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (tau): ReLU()\n", " (V_d_transform): Identity()\n", " (graph_transform): GraphTransform(\n", " (V_transform): Identity()\n", " (E_transform): Identity()\n", " )\n", " )\n", " )\n", " )\n", " (agg): MeanAggregation()\n", " (bn): BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (predictor): RegressionFFN(\n", " (ffn): MLP(\n", " (0): Sequential(\n", " (0): Linear(in_features=600, out_features=300, bias=True)\n", " )\n", " (1): Sequential(\n", " (0): ReLU()\n", " (1): Dropout(p=0.0, inplace=False)\n", " (2): Linear(in_features=300, out_features=1, bias=True)\n", " )\n", " )\n", " (criterion): MSE(task_weights=[[1.0]])\n", " (output_transform): UnscaleTransform()\n", " )\n", " (X_d_transform): Identity()\n", " (metrics): ModuleList(\n", " (0-1): 2 x MSE(task_weights=[[1.0]])\n", " )\n", ")" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mcmpnn = multi.MulticomponentMPNN.load_from_checkpoint(checkpoint_path)\n", "mcmpnn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Change predict input here" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "chemprop_dir = Path.cwd().parent\n", "test_path = chemprop_dir / \"tests\" / \"data\" / \"regression\" / \"mol+mol\" / \"mol+mol.csv\" # path to your .csv file containing SMILES strings to make predictions for\n", "smiles_columns = ['smiles', 'solvent'] # name of the column containing SMILES strings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load test smiles" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | smiles | \n", "solvent | \n", "peakwavs_max | \n", "
---|---|---|---|
0 | \n", "CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C... | \n", "ClCCl | \n", "642.0 | \n", "
1 | \n", "C(=C/c1cnccn1)\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c... | \n", "ClCCl | \n", "420.0 | \n", "
2 | \n", "CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]... | \n", "O | \n", "544.0 | \n", "
3 | \n", "c1ccc2[nH]ccc2c1 | \n", "O | \n", "290.0 | \n", "
4 | \n", "CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c... | \n", "ClC(Cl)Cl | \n", "736.0 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "
95 | \n", "COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)... | \n", "C1CCOC1 | \n", "359.0 | \n", "
96 | \n", "COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc... | \n", "C1CCCCC1 | \n", "386.0 | \n", "
97 | \n", "CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=O | \n", "CCO | \n", "425.0 | \n", "
98 | \n", "Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)... | \n", "c1ccccc1 | \n", "324.0 | \n", "
99 | \n", "Cc1ccc(C(=O)c2c(C)c3ccc4cccc5c6cccc7ccc2c(c76)... | \n", "ClCCl | \n", "391.0 | \n", "
100 rows × 3 columns
\n", "\n", " | smiles | \n", "solvent | \n", "peakwavs_max | \n", "pred | \n", "
---|---|---|---|---|
0 | \n", "CCCCN1C(=O)C(=C/C=C/C=C/C=C2N(CCCC)c3ccccc3N2C... | \n", "ClCCl | \n", "642.0 | \n", "454.898621 | \n", "
1 | \n", "C(=C/c1cnccn1)\\c1ccc(N(c2ccccc2)c2ccc(/C=C/c3c... | \n", "ClCCl | \n", "420.0 | \n", "453.561584 | \n", "
2 | \n", "CN(C)c1ccc2c(-c3ccc(N)cc3C(=O)[O-])c3ccc(=[N+]... | \n", "O | \n", "544.0 | \n", "448.694977 | \n", "
3 | \n", "c1ccc2[nH]ccc2c1 | \n", "O | \n", "290.0 | \n", "448.159760 | \n", "
4 | \n", "CCN(CC)c1ccc2c(c1)OC1=C(/C=C/C3=[N+](C)c4ccc5c... | \n", "ClC(Cl)Cl | \n", "736.0 | \n", "456.897003 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
95 | \n", "COc1ccc(C2CC(c3ccc(O)cc3)=NN2c2ccc(S(N)(=O)=O)... | \n", "C1CCOC1 | \n", "359.0 | \n", "454.548584 | \n", "
96 | \n", "COc1ccc2c3c(c4ccc(OC)cc4c2c1)C1(c2ccccc2-c2ccc... | \n", "C1CCCCC1 | \n", "386.0 | \n", "455.287140 | \n", "
97 | \n", "CCCCOc1c(C=C2N(C)c3ccccc3C2(C)C)c(=O)c1=O | \n", "CCO | \n", "425.0 | \n", "453.560364 | \n", "
98 | \n", "Cc1cc2ccc(-c3cccc4cccc(-c5ccc6cc(C)c(=O)oc6c5)... | \n", "c1ccccc1 | \n", "324.0 | \n", "454.656891 | \n", "
99 | \n", "Cc1ccc(C(=O)c2c(C)c3ccc4cccc5c6cccc7ccc2c(c76)... | \n", "ClCCl | \n", "391.0 | \n", "453.118774 | \n", "
100 rows × 4 columns
\n", "