{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Predicting Regression - Reaction" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/predicting_regression_reaction.ipynb)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install chemprop from GitHub if running in Google Colab\n", "import os\n", "\n", "if os.getenv(\"COLAB_RELEASE_TAG\"):\n", " try:\n", " import chemprop\n", " except ImportError:\n", " !git clone https://github.com/chemprop/chemprop.git\n", " %cd chemprop\n", " !pip install .\n", " %cd examples" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Import packages" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import torch\n", "from lightning import pytorch as pl\n", "from pathlib import Path\n", "\n", "from chemprop import data, featurizers, models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Change model input here" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "chemprop_dir = Path.cwd().parent\n", "checkpoint_path = chemprop_dir / \"tests\" / \"data\" / \"example_model_v2_regression_rxn.ckpt\" # path to the checkpoint file.\n", "# If the checkpoint file is generated using the training notebook, it will be in the `checkpoints` folder with name similar to `checkpoints/epoch=19-step=180.ckpt`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load model" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "MPNN(\n", " (message_passing): BondMessagePassing(\n", " (W_i): Linear(in_features=134, out_features=300, bias=False)\n", " (W_h): Linear(in_features=300, out_features=300, bias=False)\n", " (W_o): Linear(in_features=406, out_features=300, bias=True)\n", " (dropout): Dropout(p=0.0, inplace=False)\n", " (tau): ReLU()\n", " (V_d_transform): Identity()\n", " (graph_transform): GraphTransform(\n", " (V_transform): Identity()\n", " (E_transform): Identity()\n", " )\n", " )\n", " (agg): MeanAggregation()\n", " (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", " (predictor): RegressionFFN(\n", " (ffn): MLP(\n", " (0): Sequential(\n", " (0): Linear(in_features=300, out_features=300, bias=True)\n", " )\n", " (1): Sequential(\n", " (0): ReLU()\n", " (1): Dropout(p=0.0, inplace=False)\n", " (2): Linear(in_features=300, out_features=1, bias=True)\n", " )\n", " )\n", " (criterion): MSE(task_weights=[[1.0]])\n", " (output_transform): UnscaleTransform()\n", " )\n", " (X_d_transform): Identity()\n", " (metrics): ModuleList(\n", " (0-1): 2 x MSE(task_weights=[[1.0]])\n", " )\n", ")" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "mpnn = models.MPNN.load_from_checkpoint(checkpoint_path)\n", "mpnn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Change predict input here" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "chemprop_dir = Path.cwd().parent\n", "test_path = chemprop_dir / \"tests\" / \"data\" / \"regression\" / \"rxn\" / \"rxn.csv\"\n", "smiles_column = 'smiles'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load smiles" ] }, { "cell_type": 
"code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['[O:1]([C:2]([C:3]([C:4](=[O:5])[C:6]([O:7][H:15])([H:13])[H:14])([H:11])[H:12])([H:9])[H:10])[H:8]>>[C:3](=[C:4]=[O:5])([H:11])[H:12].[C:6]([O:7][H:15])([H:8])([H:13])[H:14].[O:1]=[C:2]([H:9])[H:10]',\n", " '[C:1]1([H:8])([H:9])[O:2][C@@:3]2([H:10])[C@@:4]3([H:11])[O:5][C@:6]1([H:12])[C@@:7]23[H:13]>>[C:1]1([H:8])([H:9])[O:2][C:3]([H:10])=[C:7]([H:13])[C@:6]1([O+:5]=[C-:4][H:11])[H:12]',\n", " '[C:1]([C@@:2]1([H:11])[C@@:3]2([H:12])[C:4]([H:13])([H:14])[C:5]([H:15])=[C:6]([H:16])[C@@:7]12[H:17])([H:8])([H:9])[H:10]>>[C:1]([C@@:2]1([H:11])[C:3]([H:12])([H:13])[C:4]([H:14])=[C:5]([H:15])[C:6]([H:16])=[C:7]1[H:17])([H:8])([H:9])[H:10]',\n", " '[C:1]([O:2][C:3]([C@@:4]([C:5]([H:14])([H:15])[H:16])([C:6]([O:7][H:19])([H:17])[H:18])[H:13])([H:11])[H:12])([H:8])([H:9])[H:10]>>[C-:1]([O+:2]=[C:3]([C@@:4]([C:5]([H:14])([H:15])[H:16])([C:6]([O:7][H:19])([H:17])[H:18])[H:13])[H:12])([H:8])[H:10].[H:9][H:11]',\n", " '[C:1]([C:2]#[C:3][C:4]([C:5](=[O:6])[H:12])([H:10])[H:11])([H:7])([H:8])[H:9]>>[C:1]([C:2](=[C:3]=[C:4]([H:10])[H:11])[C:5](=[O:6])[H:12])([H:7])([H:8])[H:9]'],\n", " dtype=object)" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_test = pd.read_csv(test_path)\n", "\n", "smis = df_test.loc[:, smiles_column].values\n", "smis[:5]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load datapoints" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "test_data = [data.ReactionDatapoint.from_smi(smi) for smi in smis]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define featurizer" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "featurizer = featurizers.CondensedGraphOfReactionFeaturizer(mode_=\"PROD_DIFF\")\n", "# Testing parameters should match training parameters" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Get dataset and dataloader" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "test_dset = data.ReactionDataset(test_data, featurizer=featurizer)\n", "test_loader = data.build_dataloader(test_dset, shuffle=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Perform tests" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (mps), used: False\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n", "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n", "/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. 
Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Predicting DataLoader 0: 100%|███████████████████| 2/2 [00:00<00:00, 119.42it/s]\n" ] } ], "source": [ "with torch.inference_mode():\n", " trainer = pl.Trainer(\n", " logger=None,\n", " enable_progress_bar=True,\n", " accelerator=\"cpu\",\n", " devices=1\n", " )\n", " test_preds = trainer.predict(mpnn, test_loader)" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
smileseapreds
0[O:1]([C:2]([C:3]([C:4](=[O:5])[C:6]([O:7][H:1...8.8989348.071494
1[C:1]1([H:8])([H:9])[O:2][C@@:3]2([H:10])[C@@:...5.4643288.108090
2[C:1]([C@@:2]1([H:11])[C@@:3]2([H:12])[C:4]([H...5.2705528.087680
3[C:1]([O:2][C:3]([C@@:4]([C:5]([H:14])([H:15])...8.4730068.070966
4[C:1]([C:2]#[C:3][C:4]([C:5](=[O:6])[H:12])([H...5.5790378.065533
............
95[C:1]([C:2]([C:3]([H:12])([H:13])[H:14])([C:4]...9.2956658.071316
96[O:1]=[C:2]([C@@:3]1([H:9])[C:4]([H:10])([H:11...7.7534428.085133
97[C:1]([C@@:2]1([H:11])[C@@:3]2([H:12])[C:4]([H...10.6502158.096391
98[C:1]1([H:8])([H:9])[C@@:2]2([H:10])[N:3]1[C:4...10.1389458.202709
99[C:1]([C@@:2]1([C:3]([C:4]([O:5][H:15])([H:13]...6.9799348.107012
\n", "

100 rows × 3 columns

\n", "
" ], "text/plain": [ " smiles ea preds\n", "0 [O:1]([C:2]([C:3]([C:4](=[O:5])[C:6]([O:7][H:1... 8.898934 8.071494\n", "1 [C:1]1([H:8])([H:9])[O:2][C@@:3]2([H:10])[C@@:... 5.464328 8.108090\n", "2 [C:1]([C@@:2]1([H:11])[C@@:3]2([H:12])[C:4]([H... 5.270552 8.087680\n", "3 [C:1]([O:2][C:3]([C@@:4]([C:5]([H:14])([H:15])... 8.473006 8.070966\n", "4 [C:1]([C:2]#[C:3][C:4]([C:5](=[O:6])[H:12])([H... 5.579037 8.065533\n", ".. ... ... ...\n", "95 [C:1]([C:2]([C:3]([H:12])([H:13])[H:14])([C:4]... 9.295665 8.071316\n", "96 [O:1]=[C:2]([C@@:3]1([H:9])[C:4]([H:10])([H:11... 7.753442 8.085133\n", "97 [C:1]([C@@:2]1([H:11])[C@@:3]2([H:12])[C:4]([H... 10.650215 8.096391\n", "98 [C:1]1([H:8])([H:9])[C@@:2]2([H:10])[N:3]1[C:4... 10.138945 8.202709\n", "99 [C:1]([C@@:2]1([C:3]([C:4]([O:5][H:15])([H:13]... 6.979934 8.107012\n", "\n", "[100 rows x 3 columns]" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_preds = np.concatenate(test_preds, axis=0)\n", "df_test['preds'] = test_preds\n", "df_test" ] } ], "metadata": { "kernelspec": { "display_name": "chemprop", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, "nbformat": 4, "nbformat_minor": 4 }