# Transfer Learning / Pretraining
Transfer learning (or pretraining) leverages knowledge from a pre-trained model on a related task to enhance performance on a new task. In Chemprop, we can use pre-trained model checkpoints to initialize a new model and freeze components of the new model during training, as demonstrated in this notebook.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/transfer_learning.ipynb)

In [None]:
# Install chemprop from GitHub if running in Google Colab
import os

# The COLAB_RELEASE_TAG environment variable is used as the marker for a Colab runtime;
# when it is unset or empty, nothing below is executed.
if os.getenv("COLAB_RELEASE_TAG"):
    try:
        import chemprop
    except ImportError:
        # chemprop is not importable yet: clone the repository and install it from source.
        !git clone https://github.com/chemprop/chemprop.git
        %cd chemprop
        !pip install .
        # Continue from the examples/ directory; later cells build paths from Path.cwd().parent.
        %cd examples

# Import packages

In [1]:
import pandas as pd
from pathlib import Path

from lightning import pytorch as pl
from sklearn.preprocessing import StandardScaler
import torch

from chemprop import data, featurizers, models

# Change data inputs here

In [2]:
# Repository root, assuming the notebook is executed from the `examples` directory
chemprop_dir = Path.cwd().parent
# Path to your data .csv file
input_path = chemprop_dir.joinpath("tests", "data", "regression", "mol", "mol.csv")
# Name of the column containing SMILES strings
smiles_column = 'smiles'
# List of names of the columns containing targets
target_columns = ['lipo']
# Number of workers for the dataloader; 0 means the main process does the data loading
num_workers = 0

## Load data

In [3]:
# Read the fine-tuning table; the bare name on the last line displays it in the notebook
df_input = pd.read_csv(filepath_or_buffer=input_path)
df_input

Unnamed: 0,smiles,lipo
0,Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14,3.54
1,COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)...,-1.18
2,COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl,3.69
3,OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(C...,3.37
4,Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)N...,3.10
...,...,...
95,CC(C)N(CCCNC(=O)Nc1ccc(cc1)C(C)(C)C)C[C@H]2O[C...,2.20
96,CCN(CC)CCCCNc1ncc2CN(C(=O)N(Cc3cccc(NC(=O)C=C)...,2.04
97,CCSc1c(Cc2ccccc2C(F)(F)F)sc3N(CC(C)C)C(=O)N(C)...,4.49
98,COc1ccc(Cc2c(N)n[nH]c2N)cc1,0.20


## Get SMILES and targets

In [4]:
# Raw NumPy views of the inputs: a 1-D array of SMILES and a 2-D (rows x targets) array of labels
smis = df_input[smiles_column].values
ys = df_input[target_columns].values

In [5]:
# Preview the first 5 SMILES strings (displayed as the cell output)
smis[:5]

array(['Cn1c(CN2CCN(CC2)c3ccc(Cl)cc3)nc4ccccc14',
       'COc1cc(OC)c(cc1NC(=O)CSCC(=O)O)S(=O)(=O)N2C(C)CCc3ccccc23',
       'COC(=O)[C@@H](N1CCc2sccc2C1)c3ccccc3Cl',
       'OC[C@H](O)CN1C(=O)C(Cc2ccccc12)NC(=O)c3cc4cc(Cl)sc4[nH]3',
       'Cc1cccc(C[C@H](NC(=O)c2cc(nn2C)C(C)(C)C)C(=O)NCC#N)c1'],
      dtype=object)

In [6]:
# Preview the first 5 target rows (displayed as the cell output)
ys[:5]

array([[ 3.54],
       [-1.18],
       [ 3.69],
       [ 3.37],
       [ 3.1 ]])

## Get molecule datapoints

In [7]:
# Build one MoleculeDatapoint per (SMILES, target row) pair, pairing the two arrays in lockstep
all_data = list(map(data.MoleculeDatapoint.from_smi, smis, ys))

## Perform data splitting for training, validation, and testing

In [8]:
# Names of the split strategies that can be requested when splitting the data
[*data.SplitType.keys()]

['SCAFFOLD_BALANCED',
 'RANDOM_WITH_REPEATED_SMILES',
 'RANDOM',
 'KENNARD_STONE',
 'KMEANS']

In [9]:
mols = [d.mol for d in all_data]  # RDKit Mol objects are used for structure-based splits
# "random" split with sizes (0.8, 0.1, 0.1), unpacked as train / validation / test indices
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", (0.8, 0.1, 0.1))
# Materialize the three subsets of `all_data` from those indices
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)

# Change checkpoint model inputs here
Both message-passing neural networks (MPNNs) and multi-component MPNNs can have their weights initialized from a checkpoint file.

In [10]:
# Path to the checkpoint file of the pre-trained model.
# If the checkpoint file is generated using the training notebook, it will be in the
# `checkpoints` folder with a name similar to `checkpoints/epoch=19-step=180.ckpt`.
chemprop_dir = Path.cwd().parent
checkpoint_path = chemprop_dir.joinpath("tests", "data", "example_model_v2_regression_mol.ckpt")

In [11]:
# Model class whose loader is used below to restore the single-component checkpoint
mpnn_cls = models.MPNN

In [12]:
# Restore the pre-trained model from the checkpoint file; the bare name prints its architecture
mpnn = mpnn_cls.load_from_file(checkpoint_path)
mpnn

MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.0, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): GraphTransform(
      (V_transform): Identity()
      (E_transform): Identity()
    )
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0]])
    (output_transform): UnscaleTransform()
  )
  (X_d_transform): Identity()
  (metric

# Scale fine-tuning data with the model's target scaler

If the pre-trained model was a regression model, it was probably trained on a scaled dataset. The scaler is saved as part of the model and used during prediction. For further training, we need to scale the fine-tuning data with the same target scaler.

In [13]:
# Rebuild the pre-training target scaler from the statistics stored in the model's output
# transform, so the fine-tuning targets can be scaled the same way as the pre-training targets.
# `Tensor.numpy()` only accepts CPU tensors that do not require grad, so `.detach().cpu()` is
# applied first; for tensors that are already plain CPU buffers this changes nothing.
pretraining_scaler = StandardScaler()
pretraining_scaler.mean_ = mpnn.predictor.output_transform.mean.detach().cpu().numpy()
pretraining_scaler.scale_ = mpnn.predictor.output_transform.scale.detach().cpu().numpy()

## Get MoleculeDataset

In [14]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# Each split is indexed with [0] before being wrapped in a dataset
# (presumably one entry per split replicate - confirm against `split_data_by_indices`).
train_dset, val_dset, test_dset = (
    data.MoleculeDataset(split[0], featurizer) for split in (train_data, val_data, test_data)
)

# Only the training and validation targets are put on the pre-training scale; the test
# targets are left as they are.
train_dset.normalize_targets(pretraining_scaler)
val_dset.normalize_targets(pretraining_scaler)


## Get DataLoader

In [15]:
# Training loader keeps the default shuffling behaviour of `build_dataloader`
train_loader = data.build_dataloader(train_dset, num_workers=num_workers)
# Validation and test loaders are built identically, with shuffling disabled
val_loader, test_loader = (
    data.build_dataloader(dset, num_workers=num_workers, shuffle=False)
    for dset in (val_dset, test_dset)
)

# Freezing MPNN and FFN layers
Certain layers of a pre-trained model can be kept unchanged during further training on a new task.

## Freezing the MPNN

In [16]:
# Freeze the message-passing encoder: disable gradients for all of its parameters and put it in eval mode
mpnn.message_passing.requires_grad_(False).eval()
# Freeze batch norm too; eval mode keeps its running mean and running var fixed
mpnn.bn.requires_grad_(False).eval()

BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

## Freezing FFN layers

In [17]:
# The number of consecutive FFN layers to freeze, counted from the first block of `mpnn.predictor.ffn`
frzn_ffn_layers = 1

In [18]:
for idx in range(frzn_ffn_layers):
    # Stop gradient updates for every parameter in FFN block `idx`
    mpnn.predictor.ffn[idx].requires_grad_(False)
    # In the architecture printed above, block `idx + 1` begins with the activation and dropout
    # that follow block `idx`'s linear layer, so that block is switched to eval mode.
    # NOTE(review): this also calls eval() on the still-trainable linear layer in block `idx + 1`,
    # and indexes past the last block if every block is frozen - confirm both are acceptable.
    mpnn.predictor.ffn[idx + 1].eval()

# Set up trainer

In [19]:
trainer = pl.Trainer(
    # number of epochs to train for
    max_epochs=20,
    accelerator="auto",
    devices=1,
    # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
    enable_checkpointing=True,
    enable_progress_bar=True,
    logger=False,
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


# Start training

In [20]:
# Fine-tune the partially frozen model, validating on the validation loader
trainer.fit(model=mpnn, train_dataloaders=train_loader, val_dataloaders=val_loader)

/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /Users/brianli/Documents/chemprop/examples/checkpoints exists and is not empty.
Loading `train_dataloader` to estimate number of stepping batches.
/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.

  | Name            | Type               | Params | Mode 
---------------------------------------------------------------
0 | message_passing | BondMessagePassing | 227 K  | eval 
1 | agg             | MeanAggregation    | 0      | train
2 | bn              | BatchNorm1d        | 600    | eval 
3 | predictor       | RegressionFFN      | 90.6 K | train
4 | X_d_transform   | Identity           | 0   

Sanity Checking DataLoader 0:   0%|                       | 0/1 [00:00<?, ?it/s]

/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.
  M_all = torch.zeros(len(bmg.V), H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_(


Epoch 0: 100%|█████████████| 2/2 [00:00<00:00,  5.45it/s, train_loss_step=0.871]
Validation: |                                             | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                         | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 48.68it/s][A
Epoch 1: 100%|█| 2/2 [00:00<00:00, 18.36it/s, train_loss_step=1.030, val_loss=0.[A
Validation: |                                             | 0/? [00:00<?, ?it/s][A
Validation:   0%|                                         | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|                            | 0/1 [00:00<?, ?it/s][A
Validation DataLoader 0: 100%|████████████████████| 1/1 [00:00<00:00, 19.34it/s][A
Epoch 2: 100%|█| 2/2 [00:00<00:00, 21.85it/s, train_loss_step=1.070, val_loss=0.[A
Validation: |                                             | 0/? [00:00<?, ?it/s

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|█| 2/2 [00:00<00:00,  9.91it/s, train_loss_step=1.060, val_loss=0


# Test results

In [21]:
# Evaluate the fine-tuned model on the held-out test loader
results = trainer.test(model=mpnn, dataloaders=test_loader)


/opt/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|███████████████████████| 1/1 [00:00<00:00, 21.60it/s]


# Transfer learning with multicomponent models
Multi-component MPNN models have an individual MPNN block for each molecule they parse in one input. These MPNN blocks can be frozen independently for transfer learning.

## Change data inputs here

In [22]:
# Path to the multicomponent checkpoint file, relative to the repository root
chemprop_dir = Path.cwd().parent
checkpoint_path = chemprop_dir.joinpath("tests", "data", "example_model_v2_regression_mol+mol.ckpt")

## Change checkpoint model inputs here

In [23]:
# Switch the model class to the multicomponent variant (this rebinds `mpnn_cls` used earlier)
mpnn_cls = models.MulticomponentMPNN
# Restore the pre-trained multicomponent model; the bare name prints its architecture
mcmpnn = mpnn_cls.load_from_checkpoint(checkpoint_path)
mcmpnn

MulticomponentMPNN(
  (message_passing): MulticomponentMessagePassing(
    (blocks): ModuleList(
      (0-1): 2 x BondMessagePassing(
        (W_i): Linear(in_features=86, out_features=300, bias=False)
        (W_h): Linear(in_features=300, out_features=300, bias=False)
        (W_o): Linear(in_features=372, out_features=300, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (tau): ReLU()
        (V_d_transform): Identity()
        (graph_transform): GraphTransform(
          (V_transform): Identity()
          (E_transform): Identity()
        )
      )
    )
  )
  (agg): MeanAggregation()
  (bn): BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=600, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
  

In [24]:
# A list of indices into `mcmpnn.message_passing.blocks` selecting the individual MPNN blocks
# to freeze before training.
blocks_to_freeze = [0, 1]

In [25]:
# Load the checkpoint again (presumably so re-running this cell starts from the pre-trained weights)
mcmpnn = mpnn_cls.load_from_checkpoint(checkpoint_path)
# Freeze each selected message-passing block: no gradients, eval mode
for i in blocks_to_freeze:
    mp_block = mcmpnn.message_passing.blocks[i]
    mp_block.requires_grad_(False).eval()
# Freeze the batch norm layer as well; the last expression displays it as the cell output
mcmpnn.bn.requires_grad_(False).eval()

BatchNorm1d(600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)