{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Uncertainty Quantification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/uncertainty.ipynb)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install chemprop from GitHub if running in Google Colab\n",
"import os\n",
"\n",
"if os.getenv(\"COLAB_RELEASE_TAG\"):\n",
" try:\n",
" import chemprop\n",
" except ImportError:\n",
" !git clone https://github.com/chemprop/chemprop.git\n",
" %cd chemprop\n",
" !pip install .\n",
" %cd examples"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Import packages"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import torch\n",
"import pandas as pd\n",
"from pathlib import Path\n",
"\n",
"from lightning import pytorch as pl\n",
"from lightning.pytorch.callbacks import ModelCheckpoint\n",
"\n",
"from chemprop import data, models, nn, uncertainty\n",
"from chemprop.models import save_model, load_model\n",
"from chemprop.cli.conf import NOW\n",
"from chemprop.cli.predict import find_models\n",
"\n",
"%load_ext autoreload\n",
"%autoreload 2"
]
},
{
"cell_type": "markdown",
"metadata": {
"tags": []
},
"source": [
"# Training"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loda data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"chemprop_dir = Path.cwd().parent\n",
"input_path = (\n",
" chemprop_dir / \"tests\" / \"data\" / \"regression\" / \"mol\" / \"mol.csv\"\n",
") # path to your data .csv file\n",
"df_input = pd.read_csv(input_path)\n",
"smis = df_input.loc[:, \"smiles\"].values\n",
"ys = df_input.loc[:, [\"lipo\"]].values\n",
"all_data = [data.MoleculeDatapoint.from_smi(smi, y) for smi, y in zip(smis, ys)]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"mols = [d.mol for d in all_data] # RDkit Mol objects are use for structure based splits\n",
"train_indices, val_indices, test_indices = data.make_split_indices(mols, \"random\", (0.8, 0.1, 0.1))\n",
"train_data, val_data, test_data = data.split_data_by_indices(\n",
" all_data, train_indices, val_indices, test_indices\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"train_dset = data.MoleculeDataset(train_data[0])\n",
"scaler = train_dset.normalize_targets()\n",
"\n",
"val_dset = data.MoleculeDataset(val_data[0])\n",
"val_dset.normalize_targets(scaler)\n",
"\n",
"test_dset = data.MoleculeDataset(test_data[0])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"train_loader = data.build_dataloader(train_dset)\n",
"val_loader = data.build_dataloader(val_dset, shuffle=False)\n",
"test_loader = data.build_dataloader(test_dset, shuffle=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Constructs MPNN\n",
"\n",
"- A `Message passing` constructs molecular graphs using message passing to learn node-level hidden representations.\n",
"\n",
"- An `Aggregation` is responsible for constructing a graph-level representation from the set of node-level representations after message passing.\n",
"\n",
"- A `FFN` takes the aggregated representations and make target predictions. To obtain uncertainty predictions, the `FFN` must be modified accordingly.\n",
"\n",
" For regression:\n",
" - `ffn = nn.RegressionFFN()`\n",
" - `ffn = nn.MveFFN()`\n",
" - `ffn = nn.EvidentialFFN()`\n",
"\n",
" For classification:\n",
" - `ffn = nn.BinaryClassificationFFN()`\n",
" - `ffn = nn.BinaryDirichletFFN()`\n",
" - `ffn = nn.MulticlassClassificationFFN()`\n",
" - `ffn = nn.MulticlassDirichletFFN()`\n",
"\n",
" For spectral:\n",
" - `ffn = nn.SpectralFFN()` # will be available in future version"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"MPNN(\n",
" (message_passing): BondMessagePassing(\n",
" (W_i): Linear(in_features=86, out_features=300, bias=False)\n",
" (W_h): Linear(in_features=300, out_features=300, bias=False)\n",
" (W_o): Linear(in_features=372, out_features=300, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" (tau): ReLU()\n",
" (V_d_transform): Identity()\n",
" (graph_transform): Identity()\n",
" )\n",
" (agg): MeanAggregation()\n",
" (bn): Identity()\n",
" (predictor): MveFFN(\n",
" (ffn): MLP(\n",
" (0): Sequential(\n",
" (0): Linear(in_features=300, out_features=300, bias=True)\n",
" )\n",
" (1): Sequential(\n",
" (0): ReLU()\n",
" (1): Dropout(p=0.0, inplace=False)\n",
" (2): Linear(in_features=300, out_features=2, bias=True)\n",
" )\n",
" )\n",
" (criterion): MVELoss(task_weights=[[1.0]])\n",
" (output_transform): UnscaleTransform()\n",
" )\n",
" (X_d_transform): Identity()\n",
" (metrics): ModuleList(\n",
" (0): MSE(task_weights=[[1.0]])\n",
" (1): MVELoss(task_weights=[[1.0]])\n",
" )\n",
")"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"mp = nn.BondMessagePassing()\n",
"agg = nn.MeanAggregation()\n",
"output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)\n",
"# Change to other predictor if needed.\n",
"ffn = nn.MveFFN(output_transform=output_transform)\n",
"mpnn = models.MPNN(mp, agg, ffn, batch_norm=False)\n",
"mpnn"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Set up trainer"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"model_output_dir = Path(f\"chemprop_training/{NOW}\")\n",
"monitor_mode = \"min\" if mpnn.metrics[0].higher_is_better else \"max\"\n",
"checkpointing = ModelCheckpoint(\n",
" model_output_dir / \"checkpoints\",\n",
" \"best-{epoch}-{val_loss:.2f}\",\n",
" \"val_loss\",\n",
" mode=monitor_mode,\n",
" save_last=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (mps), used: False\n",
"TPU available: False, using: 0 TPU cores\n",
"HPU available: False, using: 0 HPUs\n",
"/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n"
]
}
],
"source": [
"trainer = pl.Trainer(\n",
" logger=False,\n",
" enable_checkpointing=True,\n",
" enable_progress_bar=False,\n",
" accelerator=\"cpu\",\n",
" callbacks=[checkpointing],\n",
" devices=1,\n",
" max_epochs=20,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Start training"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"scrolled": true,
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading `train_dataloader` to estimate number of stepping batches.\n",
"/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n",
"\n",
" | Name | Type | Params | Mode \n",
"---------------------------------------------------------------\n",
"0 | message_passing | BondMessagePassing | 227 K | train\n",
"1 | agg | MeanAggregation | 0 | train\n",
"2 | bn | Identity | 0 | train\n",
"3 | predictor | MveFFN | 90.9 K | train\n",
"4 | X_d_transform | Identity | 0 | train\n",
"5 | metrics | ModuleList | 0 | train\n",
"---------------------------------------------------------------\n",
"318 K Trainable params\n",
"0 Non-trainable params\n",
"318 K Total params\n",
"1.274 Total estimated model params size (MB)\n",
"24 Modules in train mode\n",
"0 Modules in eval mode\n",
"/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n",
"`Trainer.fit` stopped: `max_epochs=20` reached.\n"
]
}
],
"source": [
"trainer.fit(mpnn, train_loader, val_loader)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Save the best model"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"best_model_path = checkpointing.best_model_path\n",
"model = mpnn.__class__.load_from_checkpoint(best_model_path)\n",
"p_model = model_output_dir / \"best.pt\"\n",
"save_model(p_model, model)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Predicting"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Change model input here"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" smiles | \n",
" lipo | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14 | \n",
" 3.54 | \n",
"
\n",
" \n",
" 1 | \n",
" COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... | \n",
" -1.18 | \n",
"
\n",
" \n",
" 2 | \n",
" COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl | \n",
" 3.69 | \n",
"
\n",
" \n",
" 3 | \n",
" OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C... | \n",
" 3.37 | \n",
"
\n",
" \n",
" 4 | \n",
" Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N... | \n",
" 3.10 | \n",
"
\n",
" \n",
" ... | \n",
" ... | \n",
" ... | \n",
"
\n",
" \n",
" 95 | \n",
" CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C... | \n",
" 2.20 | \n",
"
\n",
" \n",
" 96 | \n",
" CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)... | \n",
" 2.04 | \n",
"
\n",
" \n",
" 97 | \n",
" CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)... | \n",
" 4.49 | \n",
"
\n",
" \n",
" 98 | \n",
" COc1ccc(Cc2c(N)n[nH]c2N)cc1 | \n",
" 0.20 | \n",
"
\n",
" \n",
" 99 | \n",
" CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(... | \n",
" 2.00 | \n",
"
\n",
" \n",
"
\n",
"
100 rows × 2 columns
\n",
"
"
],
"text/plain": [
" smiles lipo\n",
"0 Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14 3.54\n",
"1 COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)... -1.18\n",
"2 COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl 3.69\n",
"3 OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C... 3.37\n",
"4 Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N... 3.10\n",
".. ... ...\n",
"95 CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C... 2.20\n",
"96 CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)... 2.04\n",
"97 CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)... 4.49\n",
"98 COc1ccc(Cc2c(N)n[nH]c2N)cc1 0.20\n",
"99 CCN(CCN(C)C)S(=O)(=O)c1ccc(cc1)c2cnc(N)c(n2)C(... 2.00\n",
"\n",
"[100 rows x 2 columns]"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chemprop_dir = Path.cwd().parent\n",
"test_path = chemprop_dir / \"tests\" / \"data\" / \"regression\" / \"mol\" / \"mol.csv\"\n",
"df_test = pd.read_csv(test_path)\n",
"test_dset = data.MoleculeDataset(test_data[0])\n",
"test_loader = data.build_dataloader(test_dset, shuffle=False)\n",
"df_test"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# use the validation set from the training as the calibration set as an example\n",
"cal_dset = data.MoleculeDataset(val_data[0])\n",
"cal_loader = data.build_dataloader(cal_dset, shuffle=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Constructs uncertainty estimator\n",
"An uncertianty estimator can make model predictions and associated uncertainty predictions.\n",
"\n",
"Available options can be found in `uncertainty.UncertaintyEstimatorRegistry`."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ClassRegistry {\n",
" 'none': ,\n",
" 'mve': ,\n",
" 'ensemble': ,\n",
" 'classification': ,\n",
" 'evidential-total': ,\n",
" 'evidential-epistemic': ,\n",
" 'evidential-aleatoric': ,\n",
" 'dropout': ,\n",
" 'classification-dirichlet': ,\n",
" 'multiclass-dirichlet': ,\n",
" 'quantile-regression': \n",
"}\n"
]
}
],
"source": [
"print(uncertainty.UncertaintyEstimatorRegistry)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"unc_estimator = uncertainty.MVEEstimator()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Constructs uncertainty calibrator\n",
"An uncertianty calibrator can calibrate the predicted uncertainties.\n",
"\n",
"Available options can be found in `uncertainty.UncertaintyCalibratorRegistry`.\n",
"\n",
"For regression:\n",
"\n",
"- ZScalingCalibrator\n",
"\n",
"- ZelikmanCalibrator\n",
"\n",
"- MVEWeightingCalibrator\n",
"\n",
"- RegressionConformalCalibrator\n",
"\n",
"For binary classification:\n",
"\n",
"- PlattCalibrator\n",
"\n",
"- IsotonicCalibrator\n",
"\n",
"- MultilabelConformalCalibrator\n",
"\n",
"For multiclass classification:\n",
"\n",
"- MulticlassConformalCalibrator\n",
"\n",
"- AdaptiveMulticlassConformalCalibrator\n",
"\n",
"- IsotonicMulticlassCalibrator"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ClassRegistry {\n",
" 'zscaling': ,\n",
" 'zelikman-interval': ,\n",
" 'mve-weighting': ,\n",
" 'conformal-regression': ,\n",
" 'platt': ,\n",
" 'isotonic': ,\n",
" 'conformal-multilabel': ,\n",
" 'conformal-multiclass': ,\n",
" 'conformal-adaptive': ,\n",
" 'isotonic-multiclass': \n",
"}\n"
]
}
],
"source": [
"print(uncertainty.UncertaintyCalibratorRegistry)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"unc_calibrator = uncertainty.ZScalingCalibrator()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Constructs uncertainty evaluator\n",
"An uncertianty evaluator can evaluates the quality of uncertainty estimates.\n",
"\n",
"Available options can be found in `uncertainty.UncertaintyEvaluatorRegistry`.\n",
"\n",
"For regression:\n",
"\n",
"- NLLRegressionEvaluator\n",
"\n",
"- CalibrationAreaEvaluator\n",
"\n",
"- ExpectedNormalizedErrorEvaluator\n",
"\n",
"- SpearmanEvaluator\n",
"\n",
"- RegressionConformalEvaluator\n",
"\n",
"For binary classification:\n",
"\n",
"- NLLClassEvaluator\n",
"\n",
"- MultilabelConformalEvaluator\n",
"\n",
"\n",
"For multiclass classification:\n",
"\n",
"- NLLMulticlassEvaluator\n",
"\n",
"- MulticlassConformalEvaluator"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ClassRegistry {\n",
" 'nll-regression': ,\n",
" 'miscalibration_area': ,\n",
" 'ence': ,\n",
" 'spearman': ,\n",
" 'conformal-coverage-regression': ,\n",
" 'nll-classification': ,\n",
" 'conformal-coverage-classification': ,\n",
" 'nll-multiclass': ,\n",
" 'conformal-coverage-multiclass': \n",
"}\n"
]
}
],
"source": [
"print(uncertainty.UncertaintyEvaluatorRegistry)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"unc_evaluators = [\n",
" uncertainty.NLLRegressionEvaluator(),\n",
" uncertainty.CalibrationAreaEvaluator(),\n",
" uncertainty.ExpectedNormalizedErrorEvaluator(),\n",
" uncertainty.SpearmanEvaluator(),\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load model"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"model_paths = find_models([model_output_dir])\n",
"models = [load_model(model_path, multicomponent=False) for model_path in model_paths]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup trainer"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"GPU available: True (mps), used: False\n",
"TPU available: False, using: 0 TPU cores\n",
"HPU available: False, using: 0 HPUs\n",
"/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n"
]
}
],
"source": [
"trainer = pl.Trainer(logger=False, enable_progress_bar=True, accelerator=\"cpu\", devices=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Make uncertainty estimation"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicting DataLoader 0: 100%|███████████████████| 1/1 [00:00<00:00, 126.93it/s]\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" smiles | \n",
" target | \n",
" pred | \n",
" unc | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" Cc1ccc(NC(=O)c2cscn2)cc1-n1cnc2ccc(N3CCN(C)CC3... | \n",
" 2.06 | \n",
" 2.047474 | \n",
" 1.543233 | \n",
"
\n",
" \n",
" 1 | \n",
" O=C(Nc1nnc(C(=O)Nc2ccc(N3CCOCC3)cc2)o1)c1ccc(C... | \n",
" 1.92 | \n",
" 2.047561 | \n",
" 1.534631 | \n",
"
\n",
" \n",
" 2 | \n",
" CNCCCC12CCC(c3ccccc31)c1ccccc12 | \n",
" 0.89 | \n",
" 2.062057 | \n",
" 1.548673 | \n",
"
\n",
" \n",
" 3 | \n",
" Oc1ncnc2scc(-c3ccsc3)c12 | \n",
" 2.25 | \n",
" 2.061813 | \n",
" 1.555989 | \n",
"
\n",
" \n",
" 4 | \n",
" C=CC(=O)Nc1cccc(CN2C(=O)N(c3c(Cl)c(OC)cc(OC)c3... | \n",
" 2.04 | \n",
" 2.038238 | \n",
" 1.532385 | \n",
"
\n",
" \n",
" 5 | \n",
" COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCCC1 | \n",
" 3.13 | \n",
" 2.048835 | \n",
" 1.535416 | \n",
"
\n",
" \n",
" 6 | \n",
" O=C(COc1ccccc1)c1ccccc1 | \n",
" 2.87 | \n",
" 2.066844 | \n",
" 1.534430 | \n",
"
\n",
" \n",
" 7 | \n",
" CC(C)c1ccc2oc3nc(N)c(C(=O)O)cc3c(=O)c2c1 | \n",
" 1.10 | \n",
" 2.053771 | \n",
" 1.550390 | \n",
"
\n",
" \n",
" 8 | \n",
" N#Cc1ccc(F)c(-c2cc(C(F)(F)F)ccc2OCC(=O)O)c1 | \n",
" -0.16 | \n",
" 2.047554 | \n",
" 1.535353 | \n",
"
\n",
" \n",
" 9 | \n",
" COc1cnc(-c2ccccn2)nc1N(C)C | \n",
" 1.90 | \n",
" 2.050501 | \n",
" 1.537318 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" smiles target pred \\\n",
"0 Cc1ccc(NC(=O)c2cscn2)cc1-n1cnc2ccc(N3CCN(C)CC3... 2.06 2.047474 \n",
"1 O=C(Nc1nnc(C(=O)Nc2ccc(N3CCOCC3)cc2)o1)c1ccc(C... 1.92 2.047561 \n",
"2 CNCCCC12CCC(c3ccccc31)c1ccccc12 0.89 2.062057 \n",
"3 Oc1ncnc2scc(-c3ccsc3)c12 2.25 2.061813 \n",
"4 C=CC(=O)Nc1cccc(CN2C(=O)N(c3c(Cl)c(OC)cc(OC)c3... 2.04 2.038238 \n",
"5 COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCCC1 3.13 2.048835 \n",
"6 O=C(COc1ccccc1)c1ccccc1 2.87 2.066844 \n",
"7 CC(C)c1ccc2oc3nc(N)c(C(=O)O)cc3c(=O)c2c1 1.10 2.053771 \n",
"8 N#Cc1ccc(F)c(-c2cc(C(F)(F)F)ccc2OCC(=O)O)c1 -0.16 2.047554 \n",
"9 COc1cnc(-c2ccccn2)nc1N(C)C 1.90 2.050501 \n",
"\n",
" unc \n",
"0 1.543233 \n",
"1 1.534631 \n",
"2 1.548673 \n",
"3 1.555989 \n",
"4 1.532385 \n",
"5 1.535416 \n",
"6 1.534430 \n",
"7 1.550390 \n",
"8 1.535353 \n",
"9 1.537318 "
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"test_predss, test_uncss = unc_estimator(test_loader, models, trainer)\n",
"test_preds = test_predss.mean(0)\n",
"test_uncs = test_uncss.mean(0)\n",
"\n",
"df_test = pd.DataFrame(\n",
" {\n",
" \"smiles\": test_dset.smiles,\n",
" \"target\": test_dset.Y.reshape(-1),\n",
" \"pred\": test_preds.reshape(-1),\n",
" \"unc\": test_uncs.reshape(-1),\n",
" }\n",
")\n",
"\n",
"df_test"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Apply uncertainty calibration"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Predicting DataLoader 0: 100%|███████████████████| 1/1 [00:00<00:00, 228.26it/s]\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" smiles | \n",
" target | \n",
" pred | \n",
" unc | \n",
" cal_unc | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" Cc1ccc(NC(=O)c2cscn2)cc1-n1cnc2ccc(N3CCN(C)CC3... | \n",
" 2.06 | \n",
" 2.047474 | \n",
" 1.543233 | \n",
" 1.691122 | \n",
"
\n",
" \n",
" 1 | \n",
" O=C(Nc1nnc(C(=O)Nc2ccc(N3CCOCC3)cc2)o1)c1ccc(C... | \n",
" 1.92 | \n",
" 2.047561 | \n",
" 1.534631 | \n",
" 1.681696 | \n",
"
\n",
" \n",
" 2 | \n",
" CNCCCC12CCC(c3ccccc31)c1ccccc12 | \n",
" 0.89 | \n",
" 2.062057 | \n",
" 1.548673 | \n",
" 1.697084 | \n",
"
\n",
" \n",
" 3 | \n",
" Oc1ncnc2scc(-c3ccsc3)c12 | \n",
" 2.25 | \n",
" 2.061813 | \n",
" 1.555989 | \n",
" 1.705101 | \n",
"
\n",
" \n",
" 4 | \n",
" C=CC(=O)Nc1cccc(CN2C(=O)N(c3c(Cl)c(OC)cc(OC)c3... | \n",
" 2.04 | \n",
" 2.038238 | \n",
" 1.532385 | \n",
" 1.679235 | \n",
"
\n",
" \n",
" 5 | \n",
" COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCCC1 | \n",
" 3.13 | \n",
" 2.048835 | \n",
" 1.535416 | \n",
" 1.682556 | \n",
"
\n",
" \n",
" 6 | \n",
" O=C(COc1ccccc1)c1ccccc1 | \n",
" 2.87 | \n",
" 2.066844 | \n",
" 1.534430 | \n",
" 1.681475 | \n",
"
\n",
" \n",
" 7 | \n",
" CC(C)c1ccc2oc3nc(N)c(C(=O)O)cc3c(=O)c2c1 | \n",
" 1.10 | \n",
" 2.053771 | \n",
" 1.550390 | \n",
" 1.698965 | \n",
"
\n",
" \n",
" 8 | \n",
" N#Cc1ccc(F)c(-c2cc(C(F)(F)F)ccc2OCC(=O)O)c1 | \n",
" -0.16 | \n",
" 2.047554 | \n",
" 1.535353 | \n",
" 1.682488 | \n",
"
\n",
" \n",
" 9 | \n",
" COc1cnc(-c2ccccn2)nc1N(C)C | \n",
" 1.90 | \n",
" 2.050501 | \n",
" 1.537318 | \n",
" 1.684641 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" smiles target pred \\\n",
"0 Cc1ccc(NC(=O)c2cscn2)cc1-n1cnc2ccc(N3CCN(C)CC3... 2.06 2.047474 \n",
"1 O=C(Nc1nnc(C(=O)Nc2ccc(N3CCOCC3)cc2)o1)c1ccc(C... 1.92 2.047561 \n",
"2 CNCCCC12CCC(c3ccccc31)c1ccccc12 0.89 2.062057 \n",
"3 Oc1ncnc2scc(-c3ccsc3)c12 2.25 2.061813 \n",
"4 C=CC(=O)Nc1cccc(CN2C(=O)N(c3c(Cl)c(OC)cc(OC)c3... 2.04 2.038238 \n",
"5 COc1cc2ncnc(Nc3ccc(F)c(Cl)c3)c2cc1OCCCN1CCCC1 3.13 2.048835 \n",
"6 O=C(COc1ccccc1)c1ccccc1 2.87 2.066844 \n",
"7 CC(C)c1ccc2oc3nc(N)c(C(=O)O)cc3c(=O)c2c1 1.10 2.053771 \n",
"8 N#Cc1ccc(F)c(-c2cc(C(F)(F)F)ccc2OCC(=O)O)c1 -0.16 2.047554 \n",
"9 COc1cnc(-c2ccccn2)nc1N(C)C 1.90 2.050501 \n",
"\n",
" unc cal_unc \n",
"0 1.543233 1.691122 \n",
"1 1.534631 1.681696 \n",
"2 1.548673 1.697084 \n",
"3 1.555989 1.705101 \n",
"4 1.532385 1.679235 \n",
"5 1.535416 1.682556 \n",
"6 1.534430 1.681475 \n",
"7 1.550390 1.698965 \n",
"8 1.535353 1.682488 \n",
"9 1.537318 1.684641 "
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cal_predss, cal_uncss = unc_estimator(cal_loader, models, trainer)\n",
"average_cal_preds = cal_predss.mean(0)\n",
"average_cal_uncs = cal_uncss.mean(0)\n",
"cal_targets = cal_dset.Y\n",
"cal_mask = torch.from_numpy(np.isfinite(cal_targets))\n",
"cal_targets = np.nan_to_num(cal_targets, nan=0.0)\n",
"cal_targets = torch.from_numpy(cal_targets)\n",
"unc_calibrator.fit(average_cal_preds, average_cal_uncs, cal_targets, cal_mask)\n",
"\n",
"cal_test_uncs = unc_calibrator.apply(test_uncs)\n",
"df_test[\"cal_unc\"] = cal_test_uncs\n",
"df_test"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Evaluate predicted uncertainty"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"nll-regression: [1.4490190356267003]\n",
"miscalibration_area: [0.15619999170303345]\n",
"ence: [0.6248166925739804]\n",
"spearman: [0.27272725105285645]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: Metric `SpearmanCorrcoef` will save all targets and predictions in the buffer. For large datasets, this may lead to large memory footprint.\n",
" warnings.warn(*args, **kwargs) # noqa: B028\n"
]
}
],
"source": [
"test_targets = test_dset.Y\n",
"test_mask = torch.from_numpy(np.isfinite(test_targets))\n",
"test_targets = np.nan_to_num(test_targets, nan=0.0)\n",
"test_targets = torch.from_numpy(test_targets)\n",
"\n",
"for evaluator in unc_evaluators:\n",
" evaluation = evaluator.evaluate(test_preds, cal_test_uncs, test_targets, test_mask)\n",
" print(f\"{evaluator.alias}: {evaluation.tolist()}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "chemprop",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}