diff --git a/chemprop-updated/chemprop/__init__.py b/chemprop-updated/chemprop/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c2a05683e98de16fe555275a8f1430a97a9008e
--- /dev/null
+++ b/chemprop-updated/chemprop/__init__.py
@@ -0,0 +1,5 @@
+from . import data, exceptions, featurizers, models, nn, schedulers, utils
+
+__all__ = ["data", "featurizers", "models", "nn", "utils", "exceptions", "schedulers"]
+
+__version__ = "2.1.2"
diff --git a/chemprop-updated/chemprop/__pycache__/__init__.cpython-37.pyc b/chemprop-updated/chemprop/__pycache__/__init__.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c9b7095c4cb5147e6fb1867912ea804331ba5bb5
Binary files /dev/null and b/chemprop-updated/chemprop/__pycache__/__init__.cpython-37.pyc differ
diff --git a/chemprop-updated/chemprop/__pycache__/args.cpython-37.pyc b/chemprop-updated/chemprop/__pycache__/args.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3e72640d8427d96850ec6362d5443893265b037c
Binary files /dev/null and b/chemprop-updated/chemprop/__pycache__/args.cpython-37.pyc differ
diff --git a/chemprop-updated/chemprop/__pycache__/constants.cpython-37.pyc b/chemprop-updated/chemprop/__pycache__/constants.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a79288e0006c7aa20463e3569efdda612645f38f
Binary files /dev/null and b/chemprop-updated/chemprop/__pycache__/constants.cpython-37.pyc differ
diff --git a/chemprop-updated/chemprop/__pycache__/hyperopt_utils.cpython-37.pyc b/chemprop-updated/chemprop/__pycache__/hyperopt_utils.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0efad6a7d927f8adfe8268cc54e3c2fad10b12e3
Binary files /dev/null and b/chemprop-updated/chemprop/__pycache__/hyperopt_utils.cpython-37.pyc differ
diff --git a/chemprop-updated/chemprop/__pycache__/hyperparameter_optimization.cpython-37.pyc b/chemprop-updated/chemprop/__pycache__/hyperparameter_optimization.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a80289e1131169088a082fda6cb6fd723970ef9d
Binary files /dev/null and b/chemprop-updated/chemprop/__pycache__/hyperparameter_optimization.cpython-37.pyc differ
diff --git a/chemprop-updated/chemprop/__pycache__/interpret.cpython-37.pyc b/chemprop-updated/chemprop/__pycache__/interpret.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e4ff850a4d89900133c1cddf13a4ae1abfd861e0
Binary files /dev/null and b/chemprop-updated/chemprop/__pycache__/interpret.cpython-37.pyc differ
diff --git a/chemprop-updated/chemprop/__pycache__/multitask_utils.cpython-37.pyc b/chemprop-updated/chemprop/__pycache__/multitask_utils.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..52f37ea7d193887ee7e45c97f61e997b896465cf
Binary files /dev/null and b/chemprop-updated/chemprop/__pycache__/multitask_utils.cpython-37.pyc differ
diff --git a/chemprop-updated/chemprop/__pycache__/nn_utils.cpython-37.pyc b/chemprop-updated/chemprop/__pycache__/nn_utils.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5edd6b47a163590b196d1b1e4dd9faa05f2175eb
Binary files /dev/null and b/chemprop-updated/chemprop/__pycache__/nn_utils.cpython-37.pyc differ
diff --git a/chemprop-updated/chemprop/__pycache__/rdkit.cpython-37.pyc b/chemprop-updated/chemprop/__pycache__/rdkit.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..556cc7ee61a15e08f0ec64256af87fdbce6d3e33
Binary files /dev/null and b/chemprop-updated/chemprop/__pycache__/rdkit.cpython-37.pyc differ
diff --git a/chemprop-updated/chemprop/__pycache__/sklearn_predict.cpython-37.pyc b/chemprop-updated/chemprop/__pycache__/sklearn_predict.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1ed0056d83f64e771778cd13fd8ad8a18f8007c6
Binary files /dev/null and b/chemprop-updated/chemprop/__pycache__/sklearn_predict.cpython-37.pyc differ
diff --git a/chemprop-updated/chemprop/__pycache__/sklearn_train.cpython-37.pyc b/chemprop-updated/chemprop/__pycache__/sklearn_train.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..e3647d20a264473bf844ab3f1bed7b7a073c710f
Binary files /dev/null and b/chemprop-updated/chemprop/__pycache__/sklearn_train.cpython-37.pyc differ
diff --git a/chemprop-updated/chemprop/__pycache__/spectra_utils.cpython-37.pyc b/chemprop-updated/chemprop/__pycache__/spectra_utils.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..16ce436997ebd26f9d522ff691aa6b39708d705e
Binary files /dev/null and b/chemprop-updated/chemprop/__pycache__/spectra_utils.cpython-37.pyc differ
diff --git a/chemprop-updated/chemprop/__pycache__/utils.cpython-37.pyc b/chemprop-updated/chemprop/__pycache__/utils.cpython-37.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d5b9d9fd85df60238ed748527d1a06c5290a2231
Binary files /dev/null and b/chemprop-updated/chemprop/__pycache__/utils.cpython-37.pyc differ
diff --git a/chemprop-updated/chemprop/cli/common.py b/chemprop-updated/chemprop/cli/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..798627387c9a490444b050d4d0a1db2aa04a7ce8
--- /dev/null
+++ b/chemprop-updated/chemprop/cli/common.py
@@ -0,0 +1,216 @@
+from argparse import ArgumentError, ArgumentParser, Namespace
+import logging
+from pathlib import Path
+
+from chemprop.cli.utils import LookupAction
+from chemprop.cli.utils.args import uppercase
+from chemprop.featurizers import AtomFeatureMode, MoleculeFeaturizerRegistry, RxnMode
+
+logger = logging.getLogger(__name__)
+
+
+def add_common_args(parser: ArgumentParser) -> ArgumentParser:
+    data_args = parser.add_argument_group("Shared input data args")
+    data_args.add_argument(
+        "-s",
+        "--smiles-columns",
+        nargs="+",
+        help="Column names in the input CSV containing SMILES strings (uses the 0th column by default)",
+    )
+    data_args.add_argument(
+        "-r",
+        "--reaction-columns",
+        nargs="+",
+        help="Column names in the input CSV containing reaction SMILES in the format ``REACTANT>AGENT>PRODUCT``, where 'AGENT' is optional",
+    )
+    data_args.add_argument(
+        "--no-header-row",
+        action="store_true",
+        help="Turn off using the first row in the input CSV as column names",
+    )
+
+    dataloader_args = parser.add_argument_group("Dataloader args")
+    dataloader_args.add_argument(
+        "-n",
+        "--num-workers",
+        type=int,
+        default=0,
+        help="""Number of workers for parallel data loading where 0 means sequential
+(Warning: setting ``num_workers`` to a value greater than 0 can cause hangs on Windows and MacOS)""",
+    )
+    dataloader_args.add_argument("-b", "--batch-size", type=int, default=64, help="Batch size")
+
+    parser.add_argument(
+        "--accelerator", default="auto", help="Passed directly to the lightning ``Trainer()``"
+    )
+    parser.add_argument(
+        "--devices",
+        default="auto",
+        help="Passed directly to the lightning ``Trainer()`` (must be a single string of comma separated devices, e.g. '1, 2' if specifying multiple devices)",
+    )
+
+    featurization_args = parser.add_argument_group("Featurization args")
+    featurization_args.add_argument(
+        "--rxn-mode",
+        "--reaction-mode",
+        type=uppercase,
+        default="REAC_DIFF",
+        choices=list(RxnMode.keys()),
+        help="""Choices for construction of atom and bond features for reactions (case insensitive):
+
+- ``REAC_PROD``: concatenates the reactants feature with the products feature
+- ``REAC_DIFF``: concatenates the reactants feature with the difference in features between reactants and products (Default)
+- ``PROD_DIFF``: concatenates the products feature with the difference in features between reactants and products
+- ``REAC_PROD_BALANCE``: concatenates the reactants feature with the products feature, balances imbalanced reactions
+- ``REAC_DIFF_BALANCE``: concatenates the reactants feature with the difference in features between reactants and products, balances imbalanced reactions
+- ``PROD_DIFF_BALANCE``: concatenates the products feature with the difference in features between reactants and products, balances imbalanced reactions""",
+    )
+    # TODO: Update documentation for multi_hot_atom_featurizer_mode
+    featurization_args.add_argument(
+        "--multi-hot-atom-featurizer-mode",
+        type=uppercase,
+        default="V2",
+        choices=list(AtomFeatureMode.keys()),
+        help="""Choices for multi-hot atom featurization scheme. This will affect both non-reaction and reaction featurization (case insensitive):
+
+- ``V1``: Corresponds to the original configuration employed in Chemprop V1
+- ``V2``: Tailored for a broad range of molecules, this configuration encompasses all elements in the first four rows of the periodic table, along with iodine. It is the default in Chemprop V2.
+- ``ORGANIC``: This configuration is designed specifically for use with organic molecules for drug research and development and includes a subset of elements most common in organic chemistry, including H, B, C, N, O, F, Si, P, S, Cl, Br, and I.
+- ``RIGR``: Modified V2 (default) featurizer using only the resonance-invariant atom and bond features.""",
+    )
+    featurization_args.add_argument(
+        "--keep-h",
+        action="store_true",
+        help="Whether hydrogens explicitly specified in input should be kept in the mol graph",
+    )
+    featurization_args.add_argument(
+        "--add-h", action="store_true", help="Whether hydrogens should be added to the mol graph"
+    )
+    data_args.add_argument(
+        "--ignore-chirality",
+        action="store_true",
+        help="Ignore chirality information in the input SMILES",
+    )
+    featurization_args.add_argument(
+        "--molecule-featurizers",
+        "--features-generators",
+        nargs="+",
+        action=LookupAction(MoleculeFeaturizerRegistry),
+        help="Method(s) of generating molecule features to use as extra descriptors",
+    )
+    # TODO: add in v2.1 to deprecate features-generators and then remove in v2.2
+    # featurization_args.add_argument(
+    #     "--features-generators", nargs="+", help="Renamed to `--molecule-featurizers`."
+    # )
+    featurization_args.add_argument(
+        "--descriptors-path",
+        type=Path,
+        help="Path to extra descriptors to concatenate to learned representation",
+    )
+    # TODO: Add in v2.1
+    # featurization_args.add_argument(
+    #     "--phase-features-path",
+    #     help="Path to features used to indicate the phase of the data in one-hot vector form.
Used in spectra datatype.", + # ) + featurization_args.add_argument( + "--no-descriptor-scaling", action="store_true", help="Turn off extra descriptor scaling" + ) + featurization_args.add_argument( + "--no-atom-feature-scaling", action="store_true", help="Turn off extra atom feature scaling" + ) + featurization_args.add_argument( + "--no-atom-descriptor-scaling", + action="store_true", + help="Turn off extra atom descriptor scaling", + ) + featurization_args.add_argument( + "--no-bond-feature-scaling", action="store_true", help="Turn off extra bond feature scaling" + ) + featurization_args.add_argument( + "--atom-features-path", + nargs="+", + action="append", + help="If a single path is given, it is assumed to correspond to the 0-th molecule. Alternatively, it can be a two-tuple of molecule index and path to additional atom features to supply before message passing (e.g., ``--atom-features-path 0 /path/to/features_0.npz``) indicates that the features at the given path should be supplied to the 0-th component. To supply additional features for multiple components, repeat this argument on the command line for each component's respective values (e.g., ``--atom-features-path [...] --atom-features-path [...]``).", + ) + featurization_args.add_argument( + "--atom-descriptors-path", + nargs="+", + action="append", + help="If a single path is given, it is assumed to correspond to the 0-th molecule. Alternatively, it can be a two-tuple of molecule index and path to additional atom descriptors to supply after message passing (e.g., ``--atom-descriptors-path 0 /path/to/descriptors_0.npz`` indicates that the descriptors at the given path should be supplied to the 0-th component. To supply additional descriptors for multiple components, repeat this argument on the command line for each component's respective values (e.g., ``--atom-descriptors-path [...] --atom-descriptors-path [...]``).", + ) + featurization_args.add_argument( + "--bond-features-path", + nargs="+", + action="append", + help="If a single path is given, it is assumed to correspond to the 0-th molecule. Alternatively, it can be a two-tuple of molecule index and path to additional bond features to supply before message passing (e.g., ``--bond-features-path 0 /path/to/features_0.npz`` indicates that the features at the given path should be supplied to the 0-th component. To supply additional features for multiple components, repeat this argument on the command line for each component's respective values (e.g., ``--bond-features-path [...] --bond-features-path [...]``).", + ) + # TODO: Add in v2.2 + # parser.add_argument( + # "--constraints-path", + # help="Path to constraints applied to atomic/bond properties prediction.", + # ) + + return parser + + +def process_common_args(args: Namespace) -> Namespace: + # TODO: add in v2.1 to deprecate features-generators and then remove in v2.2 + # if args.features_generators is not None: + # raise ArgumentError( + # argument=None, + # message="`--features-generators` has been renamed to `--molecule-featurizers`.", + # ) + + for key in ["atom_features_path", "atom_descriptors_path", "bond_features_path"]: + inds_paths = getattr(args, key) + + if not inds_paths: + continue + + ind_path_dict = {} + + for ind_path in inds_paths: + if len(ind_path) > 2: + raise ArgumentError( + argument=None, + message="Too many arguments were given for atom features/descriptors or bond features. 
It should be either a two-tuple of molecule index and a path, or a single path (assumed to be the 0-th molecule).", + ) + + if len(ind_path) == 1: + ind = 0 + path = ind_path[0] + else: + ind, path = ind_path + + if ind_path_dict.get(int(ind), None): + raise ArgumentError( + argument=None, + message=f"Duplicate atom features/descriptors or bond features given for molecule index {ind}", + ) + + ind_path_dict[int(ind)] = Path(path) + + setattr(args, key, ind_path_dict) + + return args + + +def validate_common_args(args): + pass + + +def find_models(model_paths: list[Path]): + collected_model_paths = [] + + for model_path in model_paths: + if model_path.suffix in [".ckpt", ".pt"]: + collected_model_paths.append(model_path) + elif model_path.is_dir(): + collected_model_paths.extend(list(model_path.rglob("*.pt"))) + else: + raise ArgumentError( + argument=None, + message=f"Expected a .ckpt or .pt file, or a directory. Got {model_path}", + ) + + return collected_model_paths diff --git a/chemprop-updated/chemprop/cli/conf.py b/chemprop-updated/chemprop/cli/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..be7701c52cc6509817a7cb9d4223ea33083f422b --- /dev/null +++ b/chemprop-updated/chemprop/cli/conf.py @@ -0,0 +1,9 @@ +from datetime import datetime +import logging +import os +from pathlib import Path + +LOG_DIR = Path(os.getenv("CHEMPROP_LOG_DIR", "chemprop_logs")) +LOG_LEVELS = {0: logging.INFO, 1: logging.DEBUG, -1: logging.WARNING, -2: logging.ERROR} +NOW = datetime.now().strftime("%Y-%m-%dT%H-%M-%S") +CHEMPROP_TRAIN_DIR = Path(os.getenv("CHEMPROP_TRAIN_DIR", "chemprop_training")) diff --git a/chemprop-updated/chemprop/cli/convert.py b/chemprop-updated/chemprop/cli/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..e75795e9cb19985d49108d014b447d29209340fe --- /dev/null +++ b/chemprop-updated/chemprop/cli/convert.py @@ -0,0 +1,55 @@ +from argparse import ArgumentError, ArgumentParser, Namespace +import logging +from pathlib import Path +import sys + +from chemprop.cli.utils import Subcommand +from chemprop.utils.v1_to_v2 import convert_model_file_v1_to_v2 + +logger = logging.getLogger(__name__) + + +class ConvertSubcommand(Subcommand): + COMMAND = "convert" + HELP = "Convert a v1 model checkpoint (.pt) to a v2 model checkpoint (.pt)." + + @classmethod + def add_args(cls, parser: ArgumentParser) -> ArgumentParser: + parser.add_argument( + "-i", + "--input-path", + required=True, + type=Path, + help="Path to a v1 model .pt checkpoint file", + ) + parser.add_argument( + "-o", + "--output-path", + type=Path, + help="Path to which the converted model will be saved (``CURRENT_DIRECTORY/STEM_OF_INPUT_v2.pt`` by default)", + ) + return parser + + @classmethod + def func(cls, args: Namespace): + if args.output_path is None: + args.output_path = Path(args.input_path.stem + "_v2.pt") + if args.output_path.suffix != ".pt": + raise ArgumentError( + argument=None, message=f"Output must be a `.pt` file. Got {args.output_path}" + ) + + logger.info( + f"Converting v1 model checkpoint '{args.input_path}' to v2 model checkpoint '{args.output_path}'..." 
+ ) + convert_model_file_v1_to_v2(args.input_path, args.output_path) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser = ConvertSubcommand.add_args(parser) + + logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True) + + args = parser.parse_args() + ConvertSubcommand.func(args) diff --git a/chemprop-updated/chemprop/cli/fingerprint.py b/chemprop-updated/chemprop/cli/fingerprint.py new file mode 100644 index 0000000000000000000000000000000000000000..c136730d51deb46c368b8f29ca996614e232acba --- /dev/null +++ b/chemprop-updated/chemprop/cli/fingerprint.py @@ -0,0 +1,185 @@ +from argparse import ArgumentError, ArgumentParser, Namespace +import logging +from pathlib import Path +import sys + +import numpy as np +import pandas as pd +import torch + +from chemprop import data +from chemprop.cli.common import add_common_args, process_common_args, validate_common_args +from chemprop.cli.predict import find_models +from chemprop.cli.utils import Subcommand, build_data_from_files, make_dataset +from chemprop.models import load_model +from chemprop.nn.metrics import LossFunctionRegistry + +logger = logging.getLogger(__name__) + + +class FingerprintSubcommand(Subcommand): + COMMAND = "fingerprint" + HELP = "Use a pretrained chemprop model to calculate learned representations." + + @classmethod + def add_args(cls, parser: ArgumentParser) -> ArgumentParser: + parser = add_common_args(parser) + parser.add_argument( + "-i", + "--test-path", + required=True, + type=Path, + help="Path to an input CSV file containing SMILES", + ) + parser.add_argument( + "-o", + "--output", + "--preds-path", + type=Path, + help="Specify the path where predictions will be saved. If the file extension is .npz, they will be saved as a npz file. Otherwise, the predictions will be saved as a CSV. The index of the model will be appended to the filename's stem. By default, predictions will be saved to the same location as ``--test-path`` with '_fps' appended (e.g., 'PATH/TO/TEST_PATH_fps_0.csv').", + ) + parser.add_argument( + "--model-paths", + "--model-path", + required=True, + type=Path, + nargs="+", + help="Specify location of checkpoint(s) or model file(s) to use for prediction. It can be a path to either a single pretrained model checkpoint (.ckpt) or single pretrained model file (.pt), a directory that contains these files, or a list of path(s) and directory(s). If a directory, chemprop will recursively search and predict on all found (.pt) models.", + ) + parser.add_argument( + "--ffn-block-index", + required=True, + type=int, + default=-1, + help="The index indicates which linear layer returns the encoding in the FFN. An index of 0 denotes the post-aggregation representation through a 0-layer MLP, while an index of 1 represents the output from the first linear layer in the FFN, and so forth.", + ) + + return parser + + @classmethod + def func(cls, args: Namespace): + args = process_common_args(args) + validate_common_args(args) + args = process_fingerprint_args(args) + main(args) + + +def process_fingerprint_args(args: Namespace) -> Namespace: + if args.test_path.suffix not in [".csv"]: + raise ArgumentError( + argument=None, message=f"Input data must be a CSV file. Got {args.test_path}" + ) + if args.output is None: + args.output = args.test_path.parent / (args.test_path.stem + "_fps.csv") + if args.output.suffix not in [".csv", ".npz"]: + raise ArgumentError( + argument=None, message=f"Output must be a CSV or NPZ file. Got '{args.output}'." 
+ ) + return args + + +def make_fingerprint_for_model( + args: Namespace, model_path: Path, multicomponent: bool, output_path: Path +): + model = load_model(model_path, multicomponent) + model.eval() + + bounded = any( + isinstance(model.criterion, LossFunctionRegistry[loss_function]) + for loss_function in LossFunctionRegistry.keys() + if "bounded" in loss_function + ) + + format_kwargs = dict( + no_header_row=args.no_header_row, + smiles_cols=args.smiles_columns, + rxn_cols=args.reaction_columns, + target_cols=[], + ignore_cols=None, + splits_col=None, + weight_col=None, + bounded=bounded, + ) + + featurization_kwargs = dict( + molecule_featurizers=args.molecule_featurizers, + keep_h=args.keep_h, + add_h=args.add_h, + ignore_chirality=args.ignore_chirality, + ) + + test_data = build_data_from_files( + args.test_path, + **format_kwargs, + p_descriptors=args.descriptors_path, + p_atom_feats=args.atom_features_path, + p_bond_feats=args.bond_features_path, + p_atom_descs=args.atom_descriptors_path, + **featurization_kwargs, + ) + logger.info(f"test size: {len(test_data[0])}") + test_dsets = [ + make_dataset(d, args.rxn_mode, args.multi_hot_atom_featurizer_mode) for d in test_data + ] + + if multicomponent: + test_dset = data.MulticomponentDataset(test_dsets) + else: + test_dset = test_dsets[0] + + test_loader = data.build_dataloader(test_dset, args.batch_size, args.num_workers, shuffle=False) + + logger.info(model) + + with torch.no_grad(): + if multicomponent: + encodings = [ + model.encoding(batch.bmgs, batch.V_ds, batch.X_d, args.ffn_block_index) + for batch in test_loader + ] + else: + encodings = [ + model.encoding(batch.bmg, batch.V_d, batch.X_d, args.ffn_block_index) + for batch in test_loader + ] + H = torch.cat(encodings, 0).numpy() + + if output_path.suffix in [".npz"]: + np.savez(output_path, H=H) + elif output_path.suffix == ".csv": + fingerprint_columns = [f"fp_{i}" for i in range(H.shape[1])] + df_fingerprints = pd.DataFrame(H, columns=fingerprint_columns) + df_fingerprints.to_csv(output_path, index=False) + else: + raise ArgumentError( + argument=None, message=f"Output must be a CSV or npz file. Got {args.output}." 
+ ) + logger.info(f"Fingerprints saved to '{output_path}'") + + +def main(args): + match (args.smiles_columns, args.reaction_columns): + case [None, None]: + n_components = 1 + case [_, None]: + n_components = len(args.smiles_columns) + case [None, _]: + n_components = len(args.reaction_columns) + case _: + n_components = len(args.smiles_columns) + len(args.reaction_columns) + + multicomponent = n_components > 1 + + for i, model_path in enumerate(find_models(args.model_paths)): + logger.info(f"Fingerprints with model {i} at '{model_path}'") + output_path = args.output.parent / f"{args.output.stem}_{i}{args.output.suffix}" + make_fingerprint_for_model(args, model_path, multicomponent, output_path) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser = FingerprintSubcommand.add_args(parser) + + logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True) + args = parser.parse_args() + args = FingerprintSubcommand.func(args) diff --git a/chemprop-updated/chemprop/cli/hpopt.py b/chemprop-updated/chemprop/cli/hpopt.py new file mode 100644 index 0000000000000000000000000000000000000000..f205594d9829072a88babf01b09a09618bd8f98c --- /dev/null +++ b/chemprop-updated/chemprop/cli/hpopt.py @@ -0,0 +1,540 @@ +from copy import deepcopy +import logging +from pathlib import Path +import shutil +import sys + +from configargparse import ArgumentParser, Namespace +from lightning import pytorch as pl +from lightning.pytorch.callbacks import EarlyStopping +import numpy as np +import torch + +from chemprop.cli.common import add_common_args, process_common_args, validate_common_args +from chemprop.cli.train import ( + TrainSubcommand, + add_train_args, + build_datasets, + build_model, + build_splits, + normalize_inputs, + process_train_args, + save_config, + validate_train_args, +) +from chemprop.cli.utils.command import Subcommand +from chemprop.data import build_dataloader +from chemprop.nn import AggregationRegistry, MetricRegistry +from chemprop.nn.transforms import UnscaleTransform +from chemprop.nn.utils import Activation + +NO_RAY = False +DEFAULT_SEARCH_SPACE = { + "activation": None, + "aggregation": None, + "aggregation_norm": None, + "batch_size": None, + "depth": None, + "dropout": None, + "ffn_hidden_dim": None, + "ffn_num_layers": None, + "final_lr_ratio": None, + "message_hidden_dim": None, + "init_lr_ratio": None, + "max_lr": None, + "warmup_epochs": None, +} + +try: + import ray + from ray import tune + from ray.train import CheckpointConfig, RunConfig, ScalingConfig + from ray.train.lightning import ( + RayDDPStrategy, + RayLightningEnvironment, + RayTrainReportCallback, + prepare_trainer, + ) + from ray.train.torch import TorchTrainer + from ray.tune.schedulers import ASHAScheduler, FIFOScheduler + + DEFAULT_SEARCH_SPACE = { + "activation": tune.choice(categories=list(Activation.keys())), + "aggregation": tune.choice(categories=list(AggregationRegistry.keys())), + "aggregation_norm": tune.quniform(lower=1, upper=200, q=1), + "batch_size": tune.choice([16, 32, 64, 128, 256]), + "depth": tune.qrandint(lower=2, upper=6, q=1), + "dropout": tune.choice([0.0] * 8 + list(np.arange(0.05, 0.45, 0.05))), + "ffn_hidden_dim": tune.qrandint(lower=300, upper=2400, q=100), + "ffn_num_layers": tune.qrandint(lower=1, upper=3, q=1), + "final_lr_ratio": tune.loguniform(lower=1e-2, upper=1), + "message_hidden_dim": tune.qrandint(lower=300, upper=2400, q=100), + "init_lr_ratio": tune.loguniform(lower=1e-2, upper=1), + "max_lr": tune.loguniform(lower=1e-4, upper=1e-2), + "warmup_epochs": 
None, + } +except ImportError: + NO_RAY = True + +NO_HYPEROPT = False +try: + from ray.tune.search.hyperopt import HyperOptSearch +except ImportError: + NO_HYPEROPT = True + +NO_OPTUNA = False +try: + from ray.tune.search.optuna import OptunaSearch +except ImportError: + NO_OPTUNA = True + + +logger = logging.getLogger(__name__) + +SEARCH_SPACE = DEFAULT_SEARCH_SPACE + +SEARCH_PARAM_KEYWORDS_MAP = { + "basic": ["depth", "ffn_num_layers", "dropout", "ffn_hidden_dim", "message_hidden_dim"], + "learning_rate": ["max_lr", "init_lr_ratio", "final_lr_ratio", "warmup_epochs"], + "all": list(DEFAULT_SEARCH_SPACE.keys()), + "init_lr": ["init_lr_ratio"], + "final_lr": ["final_lr_ratio"], +} + + +class HpoptSubcommand(Subcommand): + COMMAND = "hpopt" + HELP = "Perform hyperparameter optimization on the given task." + + @classmethod + def add_args(cls, parser: ArgumentParser) -> ArgumentParser: + parser = add_common_args(parser) + parser = add_train_args(parser) + return add_hpopt_args(parser) + + @classmethod + def func(cls, args: Namespace): + args = process_common_args(args) + args = process_train_args(args) + args = process_hpopt_args(args) + validate_common_args(args) + validate_train_args(args) + main(args) + + +def add_hpopt_args(parser: ArgumentParser) -> ArgumentParser: + hpopt_args = parser.add_argument_group("Chemprop hyperparameter optimization arguments") + + hpopt_args.add_argument( + "--search-parameter-keywords", + type=str, + nargs="+", + default=["basic"], + help=f"""The model parameters over which to search for an optimal hyperparameter configuration. Some options are bundles of parameters or otherwise special parameter operations. Special keywords include: + - ``basic``: Default set of hyperparameters for search (depth, ffn_num_layers, dropout, message_hidden_dim, and ffn_hidden_dim) + - ``learning_rate``: Search for max_lr, init_lr_ratio, final_lr_ratio, and warmup_epochs. The search for init_lr and final_lr values are defined as fractions of the max_lr value. The search for warmup_epochs is as a fraction of the total epochs used. + - ``all``: Include search for all 13 individual keyword options (including: activation, aggregation, aggregation_norm, and batch_size which aren't included in the other two keywords). 
+ Individual supported parameters: + {list(DEFAULT_SEARCH_SPACE.keys())} + """, + ) + + hpopt_args.add_argument( + "--hpopt-save-dir", + type=Path, + help="Directory to save the hyperparameter optimization results", + ) + + raytune_args = parser.add_argument_group("Ray Tune arguments") + + raytune_args.add_argument( + "--raytune-num-samples", + type=int, + default=10, + help="Passed directly to Ray Tune ``TuneConfig`` to control number of trials to run", + ) + + raytune_args.add_argument( + "--raytune-search-algorithm", + choices=["random", "hyperopt", "optuna"], + default="hyperopt", + help="Passed to Ray Tune ``TuneConfig`` to control search algorithm", + ) + + raytune_args.add_argument( + "--raytune-trial-scheduler", + choices=["FIFO", "AsyncHyperBand"], + default="FIFO", + help="Passed to Ray Tune ``TuneConfig`` to control trial scheduler", + ) + + raytune_args.add_argument( + "--raytune-num-workers", + type=int, + default=1, + help="Passed directly to Ray Tune ``ScalingConfig`` to control number of workers to use", + ) + + raytune_args.add_argument( + "--raytune-use-gpu", + action="store_true", + help="Passed directly to Ray Tune ``ScalingConfig`` to control whether to use GPUs", + ) + + raytune_args.add_argument( + "--raytune-num-checkpoints-to-keep", + type=int, + default=1, + help="Passed directly to Ray Tune ``CheckpointConfig`` to control number of checkpoints to keep", + ) + + raytune_args.add_argument( + "--raytune-grace-period", + type=int, + default=10, + help="Passed directly to Ray Tune ``ASHAScheduler`` to control grace period", + ) + + raytune_args.add_argument( + "--raytune-reduction-factor", + type=int, + default=2, + help="Passed directly to Ray Tune ``ASHAScheduler`` to control reduction factor", + ) + + raytune_args.add_argument( + "--raytune-temp-dir", help="Passed directly to Ray Tune init to control temporary directory" + ) + + raytune_args.add_argument( + "--raytune-num-cpus", + type=int, + help="Passed directly to Ray Tune init to control number of CPUs to use", + ) + + raytune_args.add_argument( + "--raytune-num-gpus", + type=int, + help="Passed directly to Ray Tune init to control number of GPUs to use", + ) + + raytune_args.add_argument( + "--raytune-max-concurrent-trials", + type=int, + help="Passed directly to Ray Tune TuneConfig to control maximum concurrent trials", + ) + + hyperopt_args = parser.add_argument_group("Hyperopt arguments") + + hyperopt_args.add_argument( + "--hyperopt-n-initial-points", + type=int, + help="Passed directly to ``HyperOptSearch`` to control number of initial points to sample", + ) + + hyperopt_args.add_argument( + "--hyperopt-random-state-seed", + type=int, + default=None, + help="Passed directly to ``HyperOptSearch`` to control random state seed", + ) + + return parser + + +def process_hpopt_args(args: Namespace) -> Namespace: + if args.hpopt_save_dir is None: + args.hpopt_save_dir = Path(f"chemprop_hpopt/{args.data_path.stem}") + + args.hpopt_save_dir.mkdir(exist_ok=True, parents=True) + + search_parameters = set() + + available_search_parameters = list(SEARCH_SPACE.keys()) + list(SEARCH_PARAM_KEYWORDS_MAP.keys()) + + for keyword in args.search_parameter_keywords: + if keyword not in available_search_parameters: + raise ValueError( + f"Search parameter keyword: {keyword} not in available options: {available_search_parameters}." 
+ ) + + search_parameters.update( + SEARCH_PARAM_KEYWORDS_MAP[keyword] + if keyword in SEARCH_PARAM_KEYWORDS_MAP + else [keyword] + ) + + args.search_parameter_keywords = list(search_parameters) + + if not args.hyperopt_n_initial_points: + args.hyperopt_n_initial_points = args.raytune_num_samples // 2 + + return args + + +def build_search_space(search_parameters: list[str], train_epochs: int) -> dict: + if "warmup_epochs" in search_parameters and SEARCH_SPACE.get("warmup_epochs", None) is None: + assert ( + train_epochs >= 6 + ), "Training epochs must be at least 6 to perform hyperparameter optimization for warmup_epochs." + SEARCH_SPACE["warmup_epochs"] = tune.qrandint(lower=1, upper=train_epochs // 2, q=1) + + return {param: SEARCH_SPACE[param] for param in search_parameters} + + +def update_args_with_config(args: Namespace, config: dict) -> Namespace: + args = deepcopy(args) + + for key, value in config.items(): + match key: + case "final_lr_ratio": + setattr(args, "final_lr", value * config.get("max_lr", args.max_lr)) + + case "init_lr_ratio": + setattr(args, "init_lr", value * config.get("max_lr", args.max_lr)) + + case _: + assert key in args, f"Key: {key} not found in args." + setattr(args, key, value) + + return args + + +def train_model(config, args, train_dset, val_dset, logger, output_transform, input_transforms): + args = update_args_with_config(args, config) + + train_loader = build_dataloader( + train_dset, args.batch_size, args.num_workers, seed=args.data_seed + ) + val_loader = build_dataloader(val_dset, args.batch_size, args.num_workers, shuffle=False) + + seed = args.pytorch_seed if args.pytorch_seed is not None else torch.seed() + + torch.manual_seed(seed) + + model = build_model(args, train_loader.dataset, output_transform, input_transforms) + logger.info(model) + + if args.tracking_metric == "val_loss": + T_tracking_metric = model.criterion.__class__ + else: + T_tracking_metric = MetricRegistry[args.tracking_metric] + args.tracking_metric = "val/" + args.tracking_metric + + monitor_mode = "max" if T_tracking_metric.higher_is_better else "min" + logger.debug(f"Evaluation metric: '{T_tracking_metric.alias}', mode: '{monitor_mode}'") + + patience = args.patience if args.patience is not None else args.epochs + early_stopping = EarlyStopping(args.tracking_metric, patience=patience, mode=monitor_mode) + + trainer = pl.Trainer( + accelerator=args.accelerator, + devices=args.devices, + max_epochs=args.epochs, + gradient_clip_val=args.grad_clip, + strategy=RayDDPStrategy(), + callbacks=[RayTrainReportCallback(), early_stopping], + plugins=[RayLightningEnvironment()], + deterministic=args.pytorch_seed is not None, + ) + trainer = prepare_trainer(trainer) + trainer.fit(model, train_loader, val_loader) + + +def tune_model( + args, train_dset, val_dset, logger, monitor_mode, output_transform, input_transforms +): + match args.raytune_trial_scheduler: + case "FIFO": + scheduler = FIFOScheduler() + case "AsyncHyperBand": + scheduler = ASHAScheduler( + max_t=args.epochs, + grace_period=min(args.raytune_grace_period, args.epochs), + reduction_factor=args.raytune_reduction_factor, + ) + case _: + raise ValueError(f"Invalid trial scheduler! 
got: {args.raytune_trial_scheduler}.") + + resources_per_worker = {} + if args.raytune_num_cpus and args.raytune_max_concurrent_trials: + resources_per_worker["CPU"] = args.raytune_num_cpus / args.raytune_max_concurrent_trials + if args.raytune_num_gpus and args.raytune_max_concurrent_trials: + resources_per_worker["GPU"] = args.raytune_num_gpus / args.raytune_max_concurrent_trials + if not resources_per_worker: + resources_per_worker = None + + if args.raytune_num_gpus: + use_gpu = True + else: + use_gpu = args.raytune_use_gpu + + scaling_config = ScalingConfig( + num_workers=args.raytune_num_workers, + use_gpu=use_gpu, + resources_per_worker=resources_per_worker, + trainer_resources={"CPU": 0}, + ) + + checkpoint_config = CheckpointConfig( + num_to_keep=args.raytune_num_checkpoints_to_keep, + checkpoint_score_attribute=args.tracking_metric, + checkpoint_score_order=monitor_mode, + ) + + run_config = RunConfig( + checkpoint_config=checkpoint_config, + storage_path=args.hpopt_save_dir.absolute() / "ray_results", + ) + + ray_trainer = TorchTrainer( + lambda config: train_model( + config, args, train_dset, val_dset, logger, output_transform, input_transforms + ), + scaling_config=scaling_config, + run_config=run_config, + ) + + match args.raytune_search_algorithm: + case "random": + search_alg = None + case "hyperopt": + if NO_HYPEROPT: + raise ImportError( + "HyperOptSearch requires hyperopt to be installed. Use 'pip install -U hyperopt' to install or use 'pip install -e .[hpopt]' in chemprop folder if you installed from source to install all hpopt relevant packages." + ) + + search_alg = HyperOptSearch( + n_initial_points=args.hyperopt_n_initial_points, + random_state_seed=args.hyperopt_random_state_seed, + ) + case "optuna": + if NO_OPTUNA: + raise ImportError( + "OptunaSearch requires optuna to be installed. Use 'pip install -U optuna' to install or use 'pip install -e .[hpopt]' in chemprop folder if you installed from source to install all hpopt relevant packages." + ) + + search_alg = OptunaSearch() + + tune_config = tune.TuneConfig( + metric=args.tracking_metric, + mode=monitor_mode, + num_samples=args.raytune_num_samples, + scheduler=scheduler, + search_alg=search_alg, + trial_dirname_creator=lambda trial: str(trial.trial_id), + ) + + tuner = tune.Tuner( + ray_trainer, + param_space={ + "train_loop_config": build_search_space(args.search_parameter_keywords, args.epochs) + }, + tune_config=tune_config, + ) + + return tuner.fit() + + +def main(args: Namespace): + if NO_RAY: + raise ImportError( + "Ray Tune requires ray to be installed. If you installed Chemprop from PyPI, run 'pip install -U ray[tune]' to install ray. If you installed from source, use 'pip install -e .[hpopt]' in Chemprop folder to install all hpopt relevant packages." + ) + + if not ray.is_initialized(): + try: + ray.init( + _temp_dir=args.raytune_temp_dir, + num_cpus=args.raytune_num_cpus, + num_gpus=args.raytune_num_gpus, + ) + except OSError as e: + if "AF_UNIX path length cannot exceed 107 bytes" in str(e): + raise OSError( + f"Ray Tune fails due to: {e}. This can sometimes be solved by providing a temporary directory, num_cpus, and num_gpus to Ray Tune via the CLI: --raytune-temp-dir --raytune-num-cpus --raytune-num-gpus ." 
+ ) + else: + raise e + else: + logger.info("Ray is already initialized.") + + format_kwargs = dict( + no_header_row=args.no_header_row, + smiles_cols=args.smiles_columns, + rxn_cols=args.reaction_columns, + target_cols=args.target_columns, + ignore_cols=args.ignore_columns, + splits_col=args.splits_column, + weight_col=args.weight_column, + bounded=args.loss_function is not None and "bounded" in args.loss_function, + ) + + featurization_kwargs = dict( + molecule_featurizers=args.molecule_featurizers, + keep_h=args.keep_h, + add_h=args.add_h, + ignore_chirality=args.ignore_chirality, + ) + + train_data, val_data, test_data = build_splits(args, format_kwargs, featurization_kwargs) + train_dset, val_dset, test_dset = build_datasets(args, train_data[0], val_data[0], test_data[0]) + + input_transforms = normalize_inputs(train_dset, val_dset, args) + + if "regression" in args.task_type: + output_scaler = train_dset.normalize_targets() + val_dset.normalize_targets(output_scaler) + logger.info(f"Train data: mean = {output_scaler.mean_} | std = {output_scaler.scale_}") + output_transform = UnscaleTransform.from_standard_scaler(output_scaler) + else: + output_transform = None + + train_loader = build_dataloader( + train_dset, args.batch_size, args.num_workers, seed=args.data_seed + ) + + model = build_model(args, train_loader.dataset, output_transform, input_transforms) + monitor_mode = "max" if model.metrics[0].higher_is_better else "min" + + results = tune_model( + args, train_dset, val_dset, logger, monitor_mode, output_transform, input_transforms + ) + + best_result = results.get_best_result() + best_config = best_result.config["train_loop_config"] + best_checkpoint_path = Path(best_result.checkpoint.path) / "checkpoint.ckpt" + + best_config_save_path = args.hpopt_save_dir / "best_config.toml" + best_checkpoint_save_path = args.hpopt_save_dir / "best_checkpoint.ckpt" + all_progress_save_path = args.hpopt_save_dir / "all_progress.csv" + + logger.info(f"Best hyperparameters saved to: '{best_config_save_path}'") + + args = update_args_with_config(args, best_config) + + args = TrainSubcommand.parser.parse_known_args(namespace=args)[0] + save_config(TrainSubcommand.parser, args, best_config_save_path) + + logger.info( + f"Best hyperparameter configuration checkpoint saved to '{best_checkpoint_save_path}'" + ) + + shutil.copyfile(best_checkpoint_path, best_checkpoint_save_path) + + logger.info(f"Hyperparameter optimization results saved to '{all_progress_save_path}'") + + result_df = results.get_dataframe() + + result_df.to_csv(all_progress_save_path, index=False) + + ray.shutdown() + + +if __name__ == "__main__": + parser = ArgumentParser() + parser = HpoptSubcommand.add_args(parser) + + logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True) + args = parser.parse_args() + HpoptSubcommand.func(args) diff --git a/chemprop-updated/chemprop/cli/main.py b/chemprop-updated/chemprop/cli/main.py new file mode 100644 index 0000000000000000000000000000000000000000..56d4f5205a0b7a5a50003a1c3ddba971260784ef --- /dev/null +++ b/chemprop-updated/chemprop/cli/main.py @@ -0,0 +1,85 @@ +import logging +from pathlib import Path +import sys + +from configargparse import ArgumentParser + +from chemprop.cli.conf import LOG_DIR, LOG_LEVELS, NOW +from chemprop.cli.convert import ConvertSubcommand +from chemprop.cli.fingerprint import FingerprintSubcommand +from chemprop.cli.hpopt import HpoptSubcommand +from chemprop.cli.predict import PredictSubcommand +from chemprop.cli.train import TrainSubcommand 
+from chemprop.cli.utils import pop_attr + +logger = logging.getLogger(__name__) + +SUBCOMMANDS = [ + TrainSubcommand, + PredictSubcommand, + ConvertSubcommand, + FingerprintSubcommand, + HpoptSubcommand, +] + + +def construct_parser(): + parser = ArgumentParser() + subparsers = parser.add_subparsers(title="mode", dest="mode", required=True) + + parent = ArgumentParser(add_help=False) + parent.add_argument( + "--logfile", + "--log", + nargs="?", + const="default", + help=f"Path to which the log file should be written (specifying just the flag alone will automatically log to a file ``{LOG_DIR}/MODE/TIMESTAMP.log`` , where 'MODE' is the CLI mode chosen, e.g., ``{LOG_DIR}/MODE/{NOW}.log``)", + ) + parent.add_argument("-v", action="store_true", help="Increase verbosity level to DEBUG") + parent.add_argument( + "-q", + action="count", + default=0, + help="Decrease verbosity level to WARNING or ERROR if specified twice", + ) + + parents = [parent] + for subcommand in SUBCOMMANDS: + subcommand.add(subparsers, parents) + + return parser + + +def main(): + parser = construct_parser() + args = parser.parse_args() + logfile, v_flag, q_count, mode, func = ( + pop_attr(args, attr) for attr in ["logfile", "v", "q", "mode", "func"] + ) + + if v_flag and q_count: + parser.error("The -v and -q options cannot be used together.") + + match logfile: + case None: + handler = logging.StreamHandler(sys.stderr) + case "default": + (LOG_DIR / mode).mkdir(parents=True, exist_ok=True) + handler = logging.FileHandler(str(LOG_DIR / mode / f"{NOW}.log")) + case _: + Path(logfile).parent.mkdir(parents=True, exist_ok=True) + handler = logging.FileHandler(logfile) + + verbosity = q_count * -1 if q_count else (1 if v_flag else 0) + logging_level = LOG_LEVELS.get(verbosity, logging.ERROR) + logging.basicConfig( + handlers=[handler], + format="%(asctime)s - %(levelname)s:%(name)s - %(message)s", + level=logging_level, + datefmt="%Y-%m-%dT%H:%M:%S", + force=True, + ) + + logger.info(f"Running in mode '{mode}' with args: {vars(args)}") + + func(args) diff --git a/chemprop-updated/chemprop/cli/predict.py b/chemprop-updated/chemprop/cli/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..1cfdeee0255e6a8a42782524170521f1fddefdaa --- /dev/null +++ b/chemprop-updated/chemprop/cli/predict.py @@ -0,0 +1,447 @@ +from argparse import ArgumentError, ArgumentParser, Namespace +import logging +from pathlib import Path +import sys +from typing import Iterator + +from lightning import pytorch as pl +import numpy as np +import pandas as pd +import torch + +from chemprop import data +from chemprop.cli.common import ( + add_common_args, + find_models, + process_common_args, + validate_common_args, +) +from chemprop.cli.utils import LookupAction, Subcommand, build_data_from_files, make_dataset +from chemprop.models.utils import load_model, load_output_columns +from chemprop.nn.metrics import LossFunctionRegistry +from chemprop.nn.predictors import EvidentialFFN, MulticlassClassificationFFN, MveFFN +from chemprop.uncertainty import ( + MVEWeightingCalibrator, + NoUncertaintyEstimator, + RegressionCalibrator, + RegressionEvaluator, + UncertaintyCalibratorRegistry, + UncertaintyEstimatorRegistry, + UncertaintyEvaluatorRegistry, +) +from chemprop.utils import Factory + +logger = logging.getLogger(__name__) + + +class PredictSubcommand(Subcommand): + COMMAND = "predict" + HELP = "use a pretrained chemprop model for prediction" + + @classmethod + def add_args(cls, parser: ArgumentParser) -> ArgumentParser: + parser = 
add_common_args(parser) + return add_predict_args(parser) + + @classmethod + def func(cls, args: Namespace): + args = process_common_args(args) + validate_common_args(args) + args = process_predict_args(args) + main(args) + + +def add_predict_args(parser: ArgumentParser) -> ArgumentParser: + parser.add_argument( + "-i", + "--test-path", + required=True, + type=Path, + help="Path to an input CSV file containing SMILES", + ) + parser.add_argument( + "-o", + "--output", + "--preds-path", + type=Path, + help="Specify path to which predictions will be saved. If the file extension is .pkl, it will be saved as a pickle file. Otherwise, chemprop will save predictions as a CSV. If multiple models are used to make predictions, the average predictions will be saved in the file, and another file ending in '_individual' with the same file extension will save the predictions for each individual model, with the column names being the target names appended with the model index (e.g., '_model_').", + ) + parser.add_argument( + "--drop-extra-columns", + action="store_true", + help="Whether to drop all columns from the test data file besides the SMILES columns and the new prediction columns", + ) + parser.add_argument( + "--model-paths", + "--model-path", + required=True, + type=Path, + nargs="+", + help="Location of checkpoint(s) or model file(s) to use for prediction. It can be a path to either a single pretrained model checkpoint (.ckpt) or single pretrained model file (.pt), a directory that contains these files, or a list of path(s) and directory(s). If a directory, will recursively search and predict on all found (.pt) models.", + ) + + unc_args = parser.add_argument_group("Uncertainty and calibration args") + unc_args.add_argument( + "--cal-path", type=Path, help="Path to data file to be used for uncertainty calibration." + ) + unc_args.add_argument( + "--uncertainty-method", + default="none", + action=LookupAction(UncertaintyEstimatorRegistry), + help="The method of calculating uncertainty.", + ) + unc_args.add_argument( + "--calibration-method", + action=LookupAction(UncertaintyCalibratorRegistry), + help="The method used for calibrating the uncertainty calculated with uncertainty method.", + ) + unc_args.add_argument( + "--evaluation-methods", + "--evaluation-method", + nargs="+", + action=LookupAction(UncertaintyEvaluatorRegistry), + help="The methods used for evaluating the uncertainty performance if the test data provided includes targets. Available methods are [nll, miscalibration_area, ence, spearman] or any available classification or multiclass metric.", + ) + # unc_args.add_argument( + # "--evaluation-scores-path", help="Location to save the results of uncertainty evaluations." + # ) + unc_args.add_argument( + "--uncertainty-dropout-p", + type=float, + default=0.1, + help="The probability to use for Monte Carlo dropout uncertainty estimation.", + ) + unc_args.add_argument( + "--dropout-sampling-size", + type=int, + default=10, + help="The number of samples to use for Monte Carlo dropout uncertainty estimation. Distinct from the dropout used during training.", + ) + unc_args.add_argument( + "--calibration-interval-percentile", + type=float, + default=95, + help="Sets the percentile used in the calibration methods. Must be in the range (1, 100).", + ) + unc_args.add_argument( + "--conformal-alpha", + type=float, + default=0.1, + help="Target error rate for conformal prediction. 
Must be in the range (0, 1).", + ) + # TODO: Decide if we want to implment this in v2.1.x + # unc_args.add_argument( + # "--regression-calibrator-metric", + # choices=["stdev", "interval"], + # help="Regression calibrators can output either a stdev or an inverval.", + # ) + unc_args.add_argument( + "--cal-descriptors-path", + nargs="+", + action="append", + help="Path to extra descriptors to concatenate to learned representation in calibration dataset.", + ) + # TODO: Add in v2.1.x + # unc_args.add_argument( + # "--calibration-phase-features-path", + # help=" ", + # ) + unc_args.add_argument( + "--cal-atom-features-path", + nargs="+", + action="append", + help="Path to the extra atom features in calibration dataset.", + ) + unc_args.add_argument( + "--cal-atom-descriptors-path", + nargs="+", + action="append", + help="Path to the extra atom descriptors in calibration dataset.", + ) + unc_args.add_argument( + "--cal-bond-features-path", + nargs="+", + action="append", + help="Path to the extra bond descriptors in calibration dataset.", + ) + + return parser + + +def process_predict_args(args: Namespace) -> Namespace: + if args.test_path.suffix not in [".csv"]: + raise ArgumentError( + argument=None, message=f"Input data must be a CSV file. Got {args.test_path}" + ) + if args.output is None: + args.output = args.test_path.parent / (args.test_path.stem + "_preds.csv") + if args.output.suffix not in [".csv", ".pkl"]: + raise ArgumentError( + argument=None, message=f"Output must be a CSV or Pickle file. Got {args.output}" + ) + return args + + +def prepare_data_loader( + args: Namespace, multicomponent: bool, is_calibration: bool, format_kwargs: dict +): + data_path = args.cal_path if is_calibration else args.test_path + descriptors_path = args.cal_descriptors_path if is_calibration else args.descriptors_path + atom_feats_path = args.cal_atom_features_path if is_calibration else args.atom_features_path + bond_feats_path = args.cal_bond_features_path if is_calibration else args.bond_features_path + atom_descs_path = ( + args.cal_atom_descriptors_path if is_calibration else args.atom_descriptors_path + ) + + featurization_kwargs = dict( + molecule_featurizers=args.molecule_featurizers, + keep_h=args.keep_h, + add_h=args.add_h, + ignore_chirality=args.ignore_chirality, + ) + + datas = build_data_from_files( + data_path, + **format_kwargs, + p_descriptors=descriptors_path, + p_atom_feats=atom_feats_path, + p_bond_feats=bond_feats_path, + p_atom_descs=atom_descs_path, + **featurization_kwargs, + ) + + dsets = [make_dataset(d, args.rxn_mode, args.multi_hot_atom_featurizer_mode) for d in datas] + dset = data.MulticomponentDataset(dsets) if multicomponent else dsets[0] + + return data.build_dataloader(dset, args.batch_size, args.num_workers, shuffle=False) + + +def make_prediction_for_models( + args: Namespace, model_paths: Iterator[Path], multicomponent: bool, output_path: Path +): + model = load_model(model_paths[0], multicomponent) + output_columns = load_output_columns(model_paths[0]) + bounded = any( + isinstance(model.criterion, LossFunctionRegistry[loss_function]) + for loss_function in LossFunctionRegistry.keys() + if "bounded" in loss_function + ) + format_kwargs = dict( + no_header_row=args.no_header_row, + smiles_cols=args.smiles_columns, + rxn_cols=args.reaction_columns, + ignore_cols=None, + splits_col=None, + weight_col=None, + bounded=bounded, + ) + format_kwargs["target_cols"] = output_columns if args.evaluation_methods is not None else [] + test_loader = prepare_data_loader(args, 
multicomponent, False, format_kwargs) + logger.info(f"test size: {len(test_loader.dataset)}") + if args.cal_path is not None: + format_kwargs["target_cols"] = output_columns + cal_loader = prepare_data_loader(args, multicomponent, True, format_kwargs) + logger.info(f"calibration size: {len(cal_loader.dataset)}") + + uncertainty_estimator = Factory.build( + UncertaintyEstimatorRegistry[args.uncertainty_method], + ensemble_size=args.dropout_sampling_size, + dropout=args.uncertainty_dropout_p, + ) + + models = [load_model(model_path, multicomponent) for model_path in model_paths] + trainer = pl.Trainer( + logger=False, enable_progress_bar=True, accelerator=args.accelerator, devices=args.devices + ) + test_individual_preds, test_individual_uncs = uncertainty_estimator( + test_loader, models, trainer + ) + test_preds = torch.mean(test_individual_preds, dim=0) + if not isinstance(uncertainty_estimator, NoUncertaintyEstimator): + test_uncs = torch.mean(test_individual_uncs, dim=0) + else: + test_uncs = None + + if args.calibration_method is not None: + uncertainty_calibrator = Factory.build( + UncertaintyCalibratorRegistry[args.calibration_method], + p=args.calibration_interval_percentile / 100, + alpha=args.conformal_alpha, + ) + cal_targets = cal_loader.dataset.Y + cal_mask = torch.from_numpy(np.isfinite(cal_targets)) + cal_targets = np.nan_to_num(cal_targets, nan=0.0) + cal_targets = torch.from_numpy(cal_targets) + cal_individual_preds, cal_individual_uncs = uncertainty_estimator( + cal_loader, models, trainer + ) + cal_preds = torch.mean(cal_individual_preds, dim=0) + cal_uncs = torch.mean(cal_individual_uncs, dim=0) + if isinstance(uncertainty_calibrator, MVEWeightingCalibrator): + uncertainty_calibrator.fit(cal_preds, cal_individual_uncs, cal_targets, cal_mask) + test_uncs = uncertainty_calibrator.apply(cal_individual_uncs) + else: + if isinstance(uncertainty_calibrator, RegressionCalibrator): + uncertainty_calibrator.fit(cal_preds, cal_uncs, cal_targets, cal_mask) + else: + uncertainty_calibrator.fit(cal_uncs, cal_targets, cal_mask) + test_uncs = uncertainty_calibrator.apply(test_uncs) + for i in range(test_individual_uncs.shape[0]): + test_individual_uncs[i] = uncertainty_calibrator.apply(test_individual_uncs[i]) + + if args.evaluation_methods is not None: + uncertainty_evaluators = [ + Factory.build(UncertaintyEvaluatorRegistry[method]) + for method in args.evaluation_methods + ] + logger.info("Uncertainty evaluation metric:") + for evaluator in uncertainty_evaluators: + test_targets = test_loader.dataset.Y + test_mask = torch.from_numpy(np.isfinite(test_targets)) + test_targets = np.nan_to_num(test_targets, nan=0.0) + test_targets = torch.from_numpy(test_targets) + if isinstance(evaluator, RegressionEvaluator): + metric_value = evaluator.evaluate(test_preds, test_uncs, test_targets, test_mask) + else: + metric_value = evaluator.evaluate(test_uncs, test_targets, test_mask) + logger.info(f"{evaluator.alias}: {metric_value.tolist()}") + + if args.uncertainty_method == "none" and ( + isinstance(model.predictor, MveFFN) or isinstance(model.predictor, EvidentialFFN) + ): + test_preds = test_preds[..., 0] + test_individual_preds = test_individual_preds[..., 0] + + if output_columns is None: + output_columns = [ + f"pred_{i}" for i in range(test_preds.shape[1]) + ] # TODO: need to improve this for cases like multi-task MVE and multi-task multiclass + + save_predictions(args, model, output_columns, test_preds, test_uncs, output_path) + + if len(model_paths) > 1: + save_individual_predictions( + 
args, + model, + model_paths, + output_columns, + test_individual_preds, + test_individual_uncs, + output_path, + ) + + +def save_predictions(args, model, output_columns, test_preds, test_uncs, output_path): + unc_columns = [f"{col}_unc" for col in output_columns] + + if isinstance(model.predictor, MulticlassClassificationFFN): + output_columns = output_columns + [f"{col}_prob" for col in output_columns] + predicted_class_labels = test_preds.argmax(axis=-1) + formatted_probability_strings = np.apply_along_axis( + lambda x: ",".join(map(str, x)), 2, test_preds.numpy() + ) + test_preds = np.concatenate( + (predicted_class_labels, formatted_probability_strings), axis=-1 + ) + + df_test = pd.read_csv( + args.test_path, header=None if args.no_header_row else "infer", index_col=False + ) + df_test[output_columns] = test_preds + + if args.uncertainty_method not in ["none", "classification"]: + df_test[unc_columns] = np.round(test_uncs, 6) + + if output_path.suffix == ".pkl": + df_test = df_test.reset_index(drop=True) + df_test.to_pickle(output_path) + else: + df_test.to_csv(output_path, index=False) + logger.info(f"Predictions saved to '{output_path}'") + + +def save_individual_predictions( + args, + model, + model_paths, + output_columns, + test_individual_preds, + test_individual_uncs, + output_path, +): + unc_columns = [ + f"{col}_unc_model_{i}" for i in range(len(model_paths)) for col in output_columns + ] + + if isinstance(model.predictor, MulticlassClassificationFFN): + output_columns = [ + item + for i in range(len(model_paths)) + for col in output_columns + for item in (f"{col}_model_{i}", f"{col}_prob_model_{i}") + ] + + predicted_class_labels = test_individual_preds.argmax(axis=-1) + formatted_probability_strings = np.apply_along_axis( + lambda x: ",".join(map(str, x)), 3, test_individual_preds.numpy() + ) + test_individual_preds = np.concatenate( + (predicted_class_labels, formatted_probability_strings), axis=-1 + ) + else: + output_columns = [ + f"{col}_model_{i}" for i in range(len(model_paths)) for col in output_columns + ] + + m, n, t = test_individual_preds.shape + test_individual_preds = np.transpose(test_individual_preds, (1, 0, 2)).reshape(n, m * t) + df_test = pd.read_csv( + args.test_path, header=None if args.no_header_row else "infer", index_col=False + ) + df_test[output_columns] = test_individual_preds + + if args.uncertainty_method not in ["none", "classification", "ensemble"]: + m, n, t = test_individual_uncs.shape + test_individual_uncs = np.transpose(test_individual_uncs, (1, 0, 2)).reshape(n, m * t) + df_test[unc_columns] = np.round(test_individual_uncs, 6) + + output_path = output_path.parent / Path( + str(args.output.stem) + "_individual" + str(output_path.suffix) + ) + if output_path.suffix == ".pkl": + df_test = df_test.reset_index(drop=True) + df_test.to_pickle(output_path) + else: + df_test.to_csv(output_path, index=False) + logger.info(f"Individual predictions saved to '{output_path}'") + for i, model_path in enumerate(model_paths): + logger.info( + f"Results from model path {model_path} are saved under the column name ending with 'model_{i}'" + ) + + +def main(args): + match (args.smiles_columns, args.reaction_columns): + case [None, None]: + n_components = 1 + case [_, None]: + n_components = len(args.smiles_columns) + case [None, _]: + n_components = len(args.reaction_columns) + case _: + n_components = len(args.smiles_columns) + len(args.reaction_columns) + + multicomponent = n_components > 1 + + model_paths = find_models(args.model_paths) + + 
make_prediction_for_models(args, model_paths, multicomponent, output_path=args.output) + + +if __name__ == "__main__": + parser = ArgumentParser() + parser = PredictSubcommand.add_args(parser) + + logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True) + args = parser.parse_args() + args = PredictSubcommand.func(args) diff --git a/chemprop-updated/chemprop/cli/train.py b/chemprop-updated/chemprop/cli/train.py new file mode 100644 index 0000000000000000000000000000000000000000..50ac2f4365af09c5c3f1c246dd90413e8b51e940 --- /dev/null +++ b/chemprop-updated/chemprop/cli/train.py @@ -0,0 +1,1343 @@ +from copy import deepcopy +from io import StringIO +import json +import logging +from pathlib import Path +import sys +from tempfile import TemporaryDirectory + +from configargparse import ArgumentError, ArgumentParser, Namespace +from lightning import pytorch as pl +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint +from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger +from lightning.pytorch.strategies import DDPStrategy +import numpy as np +import pandas as pd +from rich.console import Console +from rich.table import Column, Table +import torch +import torch.nn as nn + +from chemprop.cli.common import ( + add_common_args, + find_models, + process_common_args, + validate_common_args, +) +from chemprop.cli.conf import CHEMPROP_TRAIN_DIR, NOW +from chemprop.cli.utils import ( + LookupAction, + Subcommand, + build_data_from_files, + get_column_names, + make_dataset, + parse_indices, +) +from chemprop.cli.utils.args import uppercase +from chemprop.data import ( + MoleculeDataset, + MolGraphDataset, + MulticomponentDataset, + ReactionDatapoint, + SplitType, + build_dataloader, + make_split_indices, + split_data_by_indices, +) +from chemprop.data.datasets import _MolGraphDatasetMixin +from chemprop.models import MPNN, MulticomponentMPNN, save_model +from chemprop.nn import AggregationRegistry, LossFunctionRegistry, MetricRegistry, PredictorRegistry +from chemprop.nn.message_passing import ( + AtomMessagePassing, + BondMessagePassing, + MulticomponentMessagePassing, +) +from chemprop.nn.transforms import GraphTransform, ScaleTransform, UnscaleTransform +from chemprop.nn.utils import Activation +from chemprop.utils import Factory + +logger = logging.getLogger(__name__) + + +_CV_REMOVAL_ERROR = ( + "The -k/--num-folds argument was removed in v2.1.0 - use --num-replicates instead." +) + + +class TrainSubcommand(Subcommand): + COMMAND = "train" + HELP = "Train a chemprop model." 
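``TrainSubcommand`` builds on the ``Subcommand`` base class added later in this diff (``cli/utils/command.py``), which registers ``COMMAND`` and ``HELP`` on an argparse subparser and binds ``func`` as the handler. A hedged sketch of that wiring, assuming a top-level parser that is not part of this diff:

from configargparse import ArgumentParser  # train.py itself parses with configargparse

from chemprop.cli.train import TrainSubcommand

parser = ArgumentParser(prog="chemprop")
subparsers = parser.add_subparsers(dest="command")
TrainSubcommand.add(subparsers, parents=[])  # adds the "train" subparser and sets func=TrainSubcommand.func
args = parser.parse_args(["train", "-i", "data.csv"])  # "data.csv" is a hypothetical input path
args.func(args)  # runs process/validate args and then main(args) to train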
+ parser = None + + @classmethod + def add_args(cls, parser: ArgumentParser) -> ArgumentParser: + parser = add_common_args(parser) + parser = add_train_args(parser) + cls.parser = parser + return parser + + @classmethod + def func(cls, args: Namespace): + args = process_common_args(args) + validate_common_args(args) + args = process_train_args(args) + validate_train_args(args) + + args.output_dir.mkdir(exist_ok=True, parents=True) + config_path = args.output_dir / "config.toml" + save_config(cls.parser, args, config_path) + main(args) + + +def add_train_args(parser: ArgumentParser) -> ArgumentParser: + parser.add_argument( + "--config-path", + type=Path, + is_config_file=True, + help="Path to a configuration file (command line arguments override values in the configuration file)", + ) + parser.add_argument( + "-i", + "--data-path", + type=Path, + help="Path to an input CSV file containing SMILES and the associated target values", + ) + parser.add_argument( + "-o", + "--output-dir", + "--save-dir", + type=Path, + help="Directory where training outputs will be saved (defaults to ``CURRENT_DIRECTORY/chemprop_training/STEM_OF_INPUT/TIME_STAMP``)", + ) + parser.add_argument( + "--remove-checkpoints", + action="store_true", + help="Remove intermediate checkpoint files after training is complete.", + ) + + # TODO: Add in v2.1; see if we can tell lightning how often to log training loss + # parser.add_argument( + # "--log-frequency", + # type=int, + # default=10, + # help="The number of batches between each logging of the training loss.", + # ) + + transfer_args = parser.add_argument_group("transfer learning args") + transfer_args.add_argument( + "--checkpoint", + type=Path, + nargs="+", + help="Path to checkpoint(s) or model file(s) for loading and overwriting weights. Accepts a single pre-trained model checkpoint (.ckpt), a single model file (.pt), a directory containing such files, or a list of paths and directories. If a directory is provided, it will recursively search for and use all (.pt) files found for prediction.", + ) + transfer_args.add_argument( + "--freeze-encoder", + action="store_true", + help="Freeze the message passing layer from the checkpoint model (specified by ``--checkpoint``).", + ) + transfer_args.add_argument( + "--model-frzn", + help="Path to model checkpoint file to be loaded for overwriting and freezing weights. By default, all MPNN weights are frozen with this option.", + ) + transfer_args.add_argument( + "--frzn-ffn-layers", + type=int, + default=0, + help="Freeze the first ``n`` layers of the FFN from the checkpoint model (specified by ``--checkpoint``). The message passing layer should also be frozen with ``--freeze-encoder``.", + ) + # transfer_args.add_argument( + # "--freeze-first-only", + # action="store_true", + # help="Determines whether or not to use checkpoint_frzn for just the first encoder. Default (False) is to use the checkpoint to freeze all encoders. (only relevant for number_of_molecules > 1, where checkpoint model has number_of_molecules = 1)", + # ) + + # TODO: Add in v2.1 + # parser.add_argument( + # "--resume-experiment", + # action="store_true", + # help="Whether to resume the experiment. Loads test results from any folds that have already been completed and skips training those folds.", + # ) + # parser.add_argument( + # "--config-path", + # help="Path to a :code:`.json` file containing arguments. 
Any arguments present in the config file will override arguments specified via the command line or by the defaults.", + # ) + parser.add_argument( + "--ensemble-size", + type=int, + default=1, + help="Number of models in ensemble for each splitting of data", + ) + + # TODO: Add in v2.2 + # abt_args = parser.add_argument_group("atom/bond target args") + # abt_args.add_argument( + # "--is-atom-bond-targets", + # action="store_true", + # help="Whether this is atomic/bond properties prediction.", + # ) + # abt_args.add_argument( + # "--no-adding-bond-types", + # action="store_true", + # help="Whether the bond types determined by RDKit molecules added to the output of bond targets. This option is intended to be used with the :code:`is_atom_bond_targets`.", + # ) + # abt_args.add_argument( + # "--keeping-atom-map", + # action="store_true", + # help="Whether RDKit molecules keep the original atom mapping. This option is intended to be used when providing atom-mapped SMILES with the :code:`is_atom_bond_targets`.", + # ) + # abt_args.add_argument( + # "--no-shared-atom-bond-ffn", + # action="store_true", + # help="Whether the FFN weights for atom and bond targets should be independent between tasks.", + # ) + # abt_args.add_argument( + # "--weights-ffn-num-layers", + # type=int, + # default=2, + # help="Number of layers in FFN for determining weights used in constrained targets.", + # ) + + mp_args = parser.add_argument_group("message passing") + mp_args.add_argument( + "--message-hidden-dim", type=int, default=300, help="Hidden dimension of the messages" + ) + mp_args.add_argument( + "--message-bias", action="store_true", help="Add bias to the message passing layers" + ) + mp_args.add_argument("--depth", type=int, default=3, help="Number of message passing steps") + mp_args.add_argument( + "--undirected", + action="store_true", + help="Pass messages on undirected bonds/edges (always sum the two relevant bond vectors)", + ) + mp_args.add_argument( + "--dropout", + type=float, + default=0.0, + help="Dropout probability in message passing/FFN layers", + ) + mp_args.add_argument( + "--mpn-shared", + action="store_true", + help="Whether to use the same message passing neural network for all input molecules (only relevant if ``number_of_molecules`` > 1)", + ) + mp_args.add_argument( + "--activation", + type=uppercase, + default="RELU", + choices=list(Activation.keys()), + help="Activation function in message passing/FFN layers", + ) + mp_args.add_argument( + "--aggregation", + "--agg", + default="norm", + action=LookupAction(AggregationRegistry), + help="Aggregation mode to use during graph predictor", + ) + mp_args.add_argument( + "--aggregation-norm", + type=float, + default=100, + help="Normalization factor by which to divide summed up atomic features for ``norm`` aggregation", + ) + mp_args.add_argument( + "--atom-messages", action="store_true", help="Pass messages on atoms rather than bonds." 
+ ) + + # TODO: Add in v2.1 + # mpsolv_args = parser.add_argument_group("message passing with solvent") + # mpsolv_args.add_argument( + # "--reaction-solvent", + # action="store_true", + # help="Whether to adjust the MPNN layer to take as input a reaction and a molecule, and to encode them with separate MPNNs.", + # ) + # mpsolv_args.add_argument( + # "--bias-solvent", + # action="store_true", + # help="Whether to add bias to linear layers for solvent MPN if :code:`reaction_solvent` is True.", + # ) + # mpsolv_args.add_argument( + # "--hidden-size-solvent", + # type=int, + # default=300, + # help="Dimensionality of hidden layers in solvent MPN if :code:`reaction_solvent` is True.", + # ) + # mpsolv_args.add_argument( + # "--depth-solvent", + # type=int, + # default=3, + # help="Number of message passing steps for solvent if :code:`reaction_solvent` is True.", + # ) + + ffn_args = parser.add_argument_group("FFN args") + ffn_args.add_argument( + "--ffn-hidden-dim", type=int, default=300, help="Hidden dimension in the FFN top model" + ) + ffn_args.add_argument( # TODO: the default in v1 was 2. (see weights_ffn_num_layers option) Do we really want the default to now be 1? + "--ffn-num-layers", type=int, default=1, help="Number of layers in FFN top model" + ) + # TODO: Decide if we want to implment this in v2 + # ffn_args.add_argument( + # "--features-only", + # action="store_true", + # help="Use only the additional features in an FFN, no graph network.", + # ) + + extra_mpnn_args = parser.add_argument_group("extra MPNN args") + extra_mpnn_args.add_argument( + "--batch-norm", action="store_true", help="Turn on batch normalization after aggregation" + ) + extra_mpnn_args.add_argument( + "--multiclass-num-classes", + type=int, + default=3, + help="Number of classes when running multiclass classification", + ) + # TODO: Add in v2.1 + # extra_mpnn_args.add_argument( + # "--spectral-activation", + # default="exp", + # choices=["softplus", "exp"], + # help="Indicates which function to use in task_type spectra training to constrain outputs to be positive.", + # ) + + train_data_args = parser.add_argument_group("training input data args") + train_data_args.add_argument( + "-w", + "--weight-column", + help="Name of the column in the input CSV containing individual data weights", + ) + train_data_args.add_argument( + "--target-columns", + nargs="+", + help="Name of the columns containing target values (by default, uses all columns except the SMILES column and the ``ignore_columns``)", + ) + train_data_args.add_argument( + "--ignore-columns", + nargs="+", + help="Name of the columns to ignore when ``target_columns`` is not provided", + ) + train_data_args.add_argument( + "--no-cache", + action="store_true", + help="Turn off caching the featurized ``MolGraph`` s at the beginning of training", + ) + train_data_args.add_argument( + "--splits-column", + help="Name of the column in the input CSV file containing 'train', 'val', or 'test' for each row.", + ) + # TODO: Add in v2.1 + # train_data_args.add_argument( + # "--spectra-phase-mask-path", + # help="Path to a file containing a phase mask array, used for excluding particular regions in spectra predictions.", + # ) + + train_args = parser.add_argument_group("training args") + train_args.add_argument( + "-t", + "--task-type", + default="regression", + action=LookupAction(PredictorRegistry), + help="Type of dataset (determines the default loss function used during training, defaults to ``regression``)", + ) + train_args.add_argument( + "-l", + 
"--loss-function", + action=LookupAction(LossFunctionRegistry), + help="Loss function to use during training (will use the default loss function for the given task type if not specified)", + ) + train_args.add_argument( + "--v-kl", + "--evidential-regularization", + type=float, + default=0.0, + help="Specify the value used in regularization for evidential loss function. The default value recommended by Soleimany et al. (2021) is 0.2. However, the optimal value is dataset-dependent, so it is recommended that users test different values to find the best value for their model.", + ) + + train_args.add_argument( + "--eps", type=float, default=1e-8, help="Evidential regularization epsilon" + ) + train_args.add_argument( + "--alpha", type=float, default=0.1, help="Target error bounds for quantile interval loss" + ) + # TODO: Add in v2.1 + # train_args.add_argument( # TODO: Is threshold the same thing as the spectra target floor? I'm not sure but combined them. + # "-T", + # "--threshold", + # "--spectra-target-floor", + # type=float, + # default=1e-8, + # help="spectral threshold limit. v1 help string: Values in targets for dataset type spectra are replaced with this value, intended to be a small positive number used to enforce positive values.", + # ) + train_args.add_argument( + "--metrics", + "--metric", + nargs="+", + action=LookupAction(MetricRegistry), + help="Specify the evaluation metrics. If unspecified, chemprop will use the following metrics for given dataset types: regression -> ``rmse``, classification -> ``roc``, multiclass -> ``ce`` ('cross entropy'), spectral -> ``sid``. If multiple metrics are provided, the 0-th one will be used for early stopping and checkpointing.", + ) + train_args.add_argument( + "--tracking-metric", + default="val_loss", + help="The metric to track for early stopping and checkpointing. 
Defaults to the criterion used during training.", + ) + train_args.add_argument( + "--show-individual-scores", + action="store_true", + help="Show all scores for individual targets, not just average, at the end.", + ) + train_args.add_argument( + "--task-weights", + nargs="+", + type=float, + help="Weights to apply for whole tasks in the loss function", + ) + train_args.add_argument( + "--warmup-epochs", + type=int, + default=2, + help="Number of epochs during which learning rate increases linearly from ``init_lr`` to ``max_lr`` (afterwards, learning rate decreases exponentially from ``max_lr`` to ``final_lr``)", + ) + + train_args.add_argument("--init-lr", type=float, default=1e-4, help="Initial learning rate.") + train_args.add_argument("--max-lr", type=float, default=1e-3, help="Maximum learning rate.") + train_args.add_argument("--final-lr", type=float, default=1e-4, help="Final learning rate.") + train_args.add_argument("--epochs", type=int, default=50, help="Number of epochs to train over") + train_args.add_argument( + "--patience", + type=int, + default=None, + help="Number of epochs to wait for improvement before early stopping", + ) + train_args.add_argument( + "--grad-clip", + type=float, + help="Passed directly to the lightning trainer which controls grad clipping (see the ``Trainer()`` docstring for details)", + ) + train_args.add_argument( + "--class-balance", + action="store_true", + help="Ensures each training batch contains an equal number of positive and negative samples.", + ) + + split_args = parser.add_argument_group("split args") + split_args.add_argument( + "--split", + "--split-type", + type=uppercase, + default="RANDOM", + choices=list(SplitType.keys()), + help="Method of splitting the data into train/val/test (case insensitive)", + ) + split_args.add_argument( + "--split-sizes", + type=float, + nargs=3, + default=[0.8, 0.1, 0.1], + help="Split proportions for train/validation/test sets", + ) + split_args.add_argument( + "--split-key-molecule", + type=int, + default=0, + help="Specify the index of the key molecule used for splitting when multiple molecules are present and constrained split_type is used (e.g., ``scaffold_balanced`` or ``random_with_repeated_smiles``). Note that this index begins with zero for the first molecule.", + ) + split_args.add_argument("--num-replicates", type=int, default=1, help="Number of replicates.") + split_args.add_argument("-k", "--num-folds", help=_CV_REMOVAL_ERROR) + split_args.add_argument( + "--save-smiles-splits", + action="store_true", + help="Whether to store the SMILES in each train/val/test split", + ) + split_args.add_argument( + "--splits-file", + type=Path, + help="Path to a JSON file containing pre-defined splits for the input data, formatted as a list of dictionaries with keys ``train``, ``val``, and ``test`` and values as lists of indices or formatted strings (e.g. [0, 1, 2, 4] or '0-2,4')", + ) + split_args.add_argument( + "--data-seed", + type=int, + default=0, + help="Specify the random seed to use when splitting data into train/val/test sets. 
When ``--num-replicates`` > 1, the first replicate uses this seed and all subsequent replicates add 1 to the seed (also used for shuffling data in ``build_dataloader`` when ``shuffle`` is True).", + ) + + parser.add_argument( + "--pytorch-seed", + type=int, + default=None, + help="Seed for PyTorch randomness (e.g., random initial weights)", + ) + + return parser + + +def process_train_args(args: Namespace) -> Namespace: + if args.output_dir is None: + args.output_dir = CHEMPROP_TRAIN_DIR / args.data_path.stem / NOW + + return args + + +def validate_train_args(args): + if args.config_path is None and args.data_path is None: + raise ArgumentError(argument=None, message="Data path must be provided for training.") + + if args.num_folds is not None: # i.e. user-specified + raise ArgumentError(argument=None, message=_CV_REMOVAL_ERROR) + + if args.data_path.suffix not in [".csv"]: + raise ArgumentError( + argument=None, message=f"Input data must be a CSV file. Got {args.data_path}" + ) + + if args.epochs != -1 and args.epochs <= args.warmup_epochs: + raise ArgumentError( + argument=None, + message=f"The number of epochs should be higher than the number of epochs during warmup. Got {args.epochs} epochs and {args.warmup_epochs} warmup epochs", + ) + + # TODO: model_frzn is deprecated and will be removed in v2.2 + if args.checkpoint is not None and args.model_frzn is not None: + raise ArgumentError( + argument=None, + message="`--checkpoint` and `--model-frzn` cannot be used at the same time.", + ) + + if "--model-frzn" in sys.argv: + logger.warning( + "`--model-frzn` is deprecated and will be removed in v2.2. " + "Please use `--checkpoint` with `--freeze-encoder` instead." + ) + + if args.freeze_encoder and args.checkpoint is None: + raise ArgumentError( + argument=None, + message="`--freeze-encoder` can only be used when `--checkpoint` is used.", + ) + + if args.frzn_ffn_layers > 0: + if args.checkpoint is None and args.model_frzn is None: + raise ArgumentError( + argument=None, + message="`--frzn-ffn-layers` can only be used when `--checkpoint` or `--model-frzn` (deprecated in v2.1) is used.", + ) + if args.checkpoint is not None and not args.freeze_encoder: + raise ArgumentError( + argument=None, + message="To freeze the first `n` layers of the FFN via `--frzn-ffn-layers`, the message passing layer must also be frozen with `--freeze-encoder`.", + ) + + if args.class_balance and args.task_type != "classification": + raise ArgumentError( + argument=None, message="Class balance is only applicable for classification tasks." + ) + + valid_tracking_metrics = ( + args.metrics or [PredictorRegistry[args.task_type]._T_default_metric.alias] + ) + ["val_loss"] + if args.tracking_metric not in valid_tracking_metrics: + raise ArgumentError( + argument=None, + message=f"Tracking metric must be one of {','.join(valid_tracking_metrics)}. " + f"Got {args.tracking_metric}. 
Additional tracking metric options can be specified with " + "the `--metrics` flag.", + ) + + input_cols, target_cols = get_column_names( + args.data_path, + args.smiles_columns, + args.reaction_columns, + args.target_columns, + args.ignore_columns, + args.splits_column, + args.weight_column, + args.no_header_row, + ) + + args.input_columns = input_cols + args.target_columns = target_cols + + return args + + +def normalize_inputs(train_dset, val_dset, args): + multicomponent = isinstance(train_dset, MulticomponentDataset) + num_components = train_dset.n_components if multicomponent else 1 + + X_d_transform = None + V_f_transforms = [nn.Identity()] * num_components + E_f_transforms = [nn.Identity()] * num_components + V_d_transforms = [None] * num_components + graph_transforms = [] + + d_xd = train_dset.d_xd + d_vf = train_dset.d_vf + d_ef = train_dset.d_ef + d_vd = train_dset.d_vd + + if d_xd > 0 and not args.no_descriptor_scaling: + scaler = train_dset.normalize_inputs("X_d") + val_dset.normalize_inputs("X_d", scaler) + + scaler = scaler if not isinstance(scaler, list) else scaler[0] + + if scaler is not None: + logger.info( + f"Descriptors: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}" + ) + X_d_transform = ScaleTransform.from_standard_scaler(scaler) + + if d_vf > 0 and not args.no_atom_feature_scaling: + scaler = train_dset.normalize_inputs("V_f") + val_dset.normalize_inputs("V_f", scaler) + + scalers = [scaler] if not isinstance(scaler, list) else scaler + + for i, scaler in enumerate(scalers): + if scaler is None: + continue + + logger.info( + f"Atom features for mol {i}: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}" + ) + featurizer = ( + train_dset.datasets[i].featurizer if multicomponent else train_dset.featurizer + ) + V_f_transforms[i] = ScaleTransform.from_standard_scaler( + scaler, pad=featurizer.atom_fdim - featurizer.extra_atom_fdim + ) + + if d_ef > 0 and not args.no_bond_feature_scaling: + scaler = train_dset.normalize_inputs("E_f") + val_dset.normalize_inputs("E_f", scaler) + + scalers = [scaler] if not isinstance(scaler, list) else scaler + + for i, scaler in enumerate(scalers): + if scaler is None: + continue + + logger.info( + f"Bond features for mol {i}: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}" + ) + featurizer = ( + train_dset.datasets[i].featurizer if multicomponent else train_dset.featurizer + ) + E_f_transforms[i] = ScaleTransform.from_standard_scaler( + scaler, pad=featurizer.bond_fdim - featurizer.extra_bond_fdim + ) + + for V_f_transform, E_f_transform in zip(V_f_transforms, E_f_transforms): + graph_transforms.append(GraphTransform(V_f_transform, E_f_transform)) + + if d_vd > 0 and not args.no_atom_descriptor_scaling: + scaler = train_dset.normalize_inputs("V_d") + val_dset.normalize_inputs("V_d", scaler) + + scalers = [scaler] if not isinstance(scaler, list) else scaler + + for i, scaler in enumerate(scalers): + if scaler is None: + continue + + logger.info( + f"Atom descriptors for mol {i}: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}" + ) + V_d_transforms[i] = ScaleTransform.from_standard_scaler(scaler) + + return X_d_transform, graph_transforms, V_d_transforms + + +def load_and_use_pretrained_model_scalers(model_path: Path, train_dset, val_dset) -> None: + if isinstance(train_dset, MulticomponentDataset): + _model 
= MulticomponentMPNN.load_from_file(model_path) + blocks = _model.message_passing.blocks + train_dsets = train_dset.datasets + val_dsets = val_dset.datasets + else: + _model = MPNN.load_from_file(model_path) + blocks = [_model.message_passing] + train_dsets = [train_dset] + val_dsets = [val_dset] + + for i in range(len(blocks)): + if isinstance(_model.X_d_transform, ScaleTransform): + scaler = _model.X_d_transform.to_standard_scaler() + train_dsets[i].normalize_inputs("X_d", scaler) + val_dsets[i].normalize_inputs("X_d", scaler) + + if isinstance(blocks[i].graph_transform, GraphTransform): + if isinstance(blocks[i].graph_transform.V_transform, ScaleTransform): + V_anti_pad = ( + train_dsets[i].featurizer.atom_fdim - train_dsets[i].featurizer.extra_atom_fdim + ) + scaler = blocks[i].graph_transform.V_transform.to_standard_scaler( + anti_pad=V_anti_pad + ) + train_dsets[i].normalize_inputs("V_f", scaler) + val_dsets[i].normalize_inputs("V_f", scaler) + if isinstance(blocks[i].graph_transform.E_transform, ScaleTransform): + E_anti_pad = ( + train_dsets[i].featurizer.bond_fdim - train_dsets[i].featurizer.extra_bond_fdim + ) + scaler = blocks[i].graph_transform.E_transform.to_standard_scaler( + anti_pad=E_anti_pad + ) + train_dsets[i].normalize_inputs("E_f", scaler) + val_dsets[i].normalize_inputs("E_f", scaler) + + if isinstance(blocks[i].V_d_transform, ScaleTransform): + scaler = blocks[i].V_d_transform.to_standard_scaler() + train_dsets[i].normalize_inputs("V_d", scaler) + val_dsets[i].normalize_inputs("V_d", scaler) + + if isinstance(_model.predictor.output_transform, UnscaleTransform): + scaler = _model.predictor.output_transform.to_standard_scaler() + train_dset.normalize_targets(scaler) + val_dset.normalize_targets(scaler) + + +def save_config(parser: ArgumentParser, args: Namespace, config_path: Path): + config_args = deepcopy(args) + for key, value in vars(config_args).items(): + if isinstance(value, Path): + setattr(config_args, key, str(value)) + + for key in ["atom_features_path", "atom_descriptors_path", "bond_features_path"]: + if getattr(config_args, key) is not None: + for index, path in getattr(config_args, key).items(): + getattr(config_args, key)[index] = str(path) + + parser.write_config_file(parsed_namespace=config_args, output_file_paths=[str(config_path)]) + + +def save_smiles_splits(args: Namespace, output_dir, train_dset, val_dset, test_dset): + match (args.smiles_columns, args.reaction_columns): + case [_, None]: + column_labels = deepcopy(args.smiles_columns) + case [None, _]: + column_labels = deepcopy(args.reaction_columns) + case _: + column_labels = deepcopy(args.smiles_columns) + column_labels.extend(args.reaction_columns) + + train_smis = train_dset.names + df_train = pd.DataFrame(train_smis, columns=column_labels) + df_train.to_csv(output_dir / "train_smiles.csv", index=False) + + val_smis = val_dset.names + df_val = pd.DataFrame(val_smis, columns=column_labels) + df_val.to_csv(output_dir / "val_smiles.csv", index=False) + + if test_dset is not None: + test_smis = test_dset.names + df_test = pd.DataFrame(test_smis, columns=column_labels) + df_test.to_csv(output_dir / "test_smiles.csv", index=False) + + +def build_splits(args, format_kwargs, featurization_kwargs): + """build the train/val/test splits""" + logger.info(f"Pulling data from file: {args.data_path}") + all_data = build_data_from_files( + args.data_path, + p_descriptors=args.descriptors_path, + p_atom_feats=args.atom_features_path, + p_bond_feats=args.bond_features_path, + 
p_atom_descs=args.atom_descriptors_path, + **format_kwargs, + **featurization_kwargs, + ) + + if args.splits_column is not None: + df = pd.read_csv( + args.data_path, header=None if args.no_header_row else "infer", index_col=False + ) + grouped = df.groupby(df[args.splits_column].str.lower()) + train_indices = grouped.groups.get("train", pd.Index([])).tolist() + val_indices = grouped.groups.get("val", pd.Index([])).tolist() + test_indices = grouped.groups.get("test", pd.Index([])).tolist() + train_indices, val_indices, test_indices = [train_indices], [val_indices], [test_indices] + + elif args.splits_file is not None: + with open(args.splits_file, "rb") as json_file: + split_idxss = json.load(json_file) + train_indices = [parse_indices(d["train"]) for d in split_idxss] + val_indices = [parse_indices(d["val"]) for d in split_idxss] + test_indices = [parse_indices(d["test"]) for d in split_idxss] + args.num_replicates = len(split_idxss) + + else: + splitting_data = all_data[args.split_key_molecule] + if isinstance(splitting_data[0], ReactionDatapoint): + splitting_mols = [datapoint.rct for datapoint in splitting_data] + else: + splitting_mols = [datapoint.mol for datapoint in splitting_data] + train_indices, val_indices, test_indices = make_split_indices( + splitting_mols, args.split, args.split_sizes, args.data_seed, args.num_replicates + ) + + train_data, val_data, test_data = split_data_by_indices( + all_data, train_indices, val_indices, test_indices + ) + for i_split in range(len(train_data)): + sizes = [len(train_data[i_split][0]), len(val_data[i_split][0]), len(test_data[i_split][0])] + logger.info(f"train/val/test split_{i_split} sizes: {sizes}") + + return train_data, val_data, test_data + + +def summarize( + target_cols: list[str], task_type: str, dataset: _MolGraphDatasetMixin +) -> tuple[list, list]: + if task_type in [ + "regression", + "regression-mve", + "regression-evidential", + "regression-quantile", + ]: + if isinstance(dataset, MulticomponentDataset): + y = dataset.datasets[0].Y + else: + y = dataset.Y + y_mean = np.nanmean(y, axis=0) + y_std = np.nanstd(y, axis=0) + y_median = np.nanmedian(y, axis=0) + mean_dev_abs = np.abs(y - y_mean) + num_targets = np.sum(~np.isnan(y), axis=0) + frac_1_sigma = np.sum((mean_dev_abs < y_std), axis=0) / num_targets + frac_2_sigma = np.sum((mean_dev_abs < 2 * y_std), axis=0) / num_targets + + column_headers = ["Statistic"] + [f"Value ({target_cols[i]})" for i in range(y.shape[1])] + table_rows = [ + ["Num. smiles"] + [f"{len(y)}" for i in range(y.shape[1])], + ["Num. targets"] + [f"{num_targets[i]}" for i in range(y.shape[1])], + ["Num. NaN"] + [f"{len(y) - num_targets[i]}" for i in range(y.shape[1])], + ["Mean"] + [f"{mean:0.3g}" for mean in y_mean], + ["Std. 
dev."] + [f"{std:0.3g}" for std in y_std], + ["Median"] + [f"{median:0.3g}" for median in y_median], + ["% within 1 s.d."] + [f"{sigma:0.0%}" for sigma in frac_1_sigma], + ["% within 2 s.d."] + [f"{sigma:0.0%}" for sigma in frac_2_sigma], + ] + return (column_headers, table_rows) + elif task_type in [ + "classification", + "classification-dirichlet", + "multiclass", + "multiclass-dirichlet", + ]: + if isinstance(dataset, MulticomponentDataset): + y = dataset.datasets[0].Y + else: + y = dataset.Y + + mask = np.isnan(y) + classes = np.sort(np.unique(y[~mask])) + + class_counts = np.stack([(classes[:, None] == y[:, i]).sum(1) for i in range(y.shape[1])]) + class_fracs = class_counts / y.shape[0] + nan_count = np.nansum(mask, axis=0) + nan_frac = nan_count / y.shape[0] + + column_headers = ["Class"] + [f"Count/Percent {target_cols[i]}" for i in range(y.shape[1])] + + table_rows = [ + [f"{k}"] + [f"{class_counts[j, i]}/{class_fracs[j, i]:0.0%}" for j in range(y.shape[1])] + for i, k in enumerate(classes) + ] + + nan_row = ["NaN"] + [f"{nan_count[i]}/{nan_frac[i]:0.0%}" for i in range(y.shape[1])] + table_rows.append(nan_row) + + total_row = ["Total"] + [f"{y.shape[0]}/{100.00}%" for i in range(y.shape[1])] + table_rows.append(total_row) + + return (column_headers, table_rows) + else: + raise ValueError(f"unsupported task type! Task type '{task_type}' was not recognized.") + + +def build_table(column_headers: list[str], table_rows: list[str], title: str | None = None) -> str: + right_justified_columns = [ + Column(header=column_header, justify="right") for column_header in column_headers + ] + table = Table(*right_justified_columns, title=title) + for row in table_rows: + table.add_row(*row) + + console = Console(record=True, file=StringIO(), width=200) + console.print(table) + return console.export_text() + + +def build_datasets(args, train_data, val_data, test_data): + """build the train/val/test datasets, where :attr:`test_data` may be None""" + multicomponent = len(train_data) > 1 + if multicomponent: + train_dsets = [ + make_dataset(data, args.rxn_mode, args.multi_hot_atom_featurizer_mode) + for data in train_data + ] + val_dsets = [ + make_dataset(data, args.rxn_mode, args.multi_hot_atom_featurizer_mode) + for data in val_data + ] + train_dset = MulticomponentDataset(train_dsets) + val_dset = MulticomponentDataset(val_dsets) + if len(test_data[0]) > 0: + test_dsets = [ + make_dataset(data, args.rxn_mode, args.multi_hot_atom_featurizer_mode) + for data in test_data + ] + test_dset = MulticomponentDataset(test_dsets) + else: + test_dset = None + else: + train_data = train_data[0] + val_data = val_data[0] + test_data = test_data[0] + train_dset = make_dataset(train_data, args.rxn_mode, args.multi_hot_atom_featurizer_mode) + val_dset = make_dataset(val_data, args.rxn_mode, args.multi_hot_atom_featurizer_mode) + if len(test_data) > 0: + test_dset = make_dataset(test_data, args.rxn_mode, args.multi_hot_atom_featurizer_mode) + else: + test_dset = None + if args.task_type != "spectral": + for dataset, label in zip( + [train_dset, val_dset, test_dset], ["Training", "Validation", "Test"] + ): + column_headers, table_rows = summarize(args.target_columns, args.task_type, dataset) + output = build_table(column_headers, table_rows, f"Summary of {label} Data") + logger.info("\n" + output) + + return train_dset, val_dset, test_dset + + +def build_model( + args, + train_dset: MolGraphDataset | MulticomponentDataset, + output_transform: UnscaleTransform, + input_transforms: tuple[ScaleTransform, 
list[GraphTransform], list[ScaleTransform]], +) -> MPNN: + mp_cls = AtomMessagePassing if args.atom_messages else BondMessagePassing + + X_d_transform, graph_transforms, V_d_transforms = input_transforms + if isinstance(train_dset, MulticomponentDataset): + mp_blocks = [ + mp_cls( + train_dset.datasets[i].featurizer.atom_fdim, + train_dset.datasets[i].featurizer.bond_fdim, + d_h=args.message_hidden_dim, + d_vd=( + train_dset.datasets[i].d_vd + if isinstance(train_dset.datasets[i], MoleculeDataset) + else 0 + ), + bias=args.message_bias, + depth=args.depth, + undirected=args.undirected, + dropout=args.dropout, + activation=args.activation, + V_d_transform=V_d_transforms[i], + graph_transform=graph_transforms[i], + ) + for i in range(train_dset.n_components) + ] + if args.mpn_shared: + if args.reaction_columns is not None and args.smiles_columns is not None: + raise ArgumentError( + argument=None, + message="Cannot use shared MPNN with both molecule and reaction data.", + ) + + mp_block = MulticomponentMessagePassing(mp_blocks, train_dset.n_components, args.mpn_shared) + # NOTE(degraff): this if/else block should be handled by the init of MulticomponentMessagePassing + # if args.mpn_shared: + # mp_block = MulticomponentMessagePassing(mp_blocks[0], n_components, args.mpn_shared) + # else: + d_xd = train_dset.datasets[0].d_xd + n_tasks = train_dset.datasets[0].Y.shape[1] + mpnn_cls = MulticomponentMPNN + else: + mp_block = mp_cls( + train_dset.featurizer.atom_fdim, + train_dset.featurizer.bond_fdim, + d_h=args.message_hidden_dim, + d_vd=train_dset.d_vd if isinstance(train_dset, MoleculeDataset) else 0, + bias=args.message_bias, + depth=args.depth, + undirected=args.undirected, + dropout=args.dropout, + activation=args.activation, + V_d_transform=V_d_transforms[0], + graph_transform=graph_transforms[0], + ) + d_xd = train_dset.d_xd + n_tasks = train_dset.Y.shape[1] + mpnn_cls = MPNN + + agg = Factory.build(AggregationRegistry[args.aggregation], norm=args.aggregation_norm) + predictor_cls = PredictorRegistry[args.task_type] + if args.loss_function is not None: + task_weights = torch.ones(n_tasks) if args.task_weights is None else args.task_weights + criterion = Factory.build( + LossFunctionRegistry[args.loss_function], + task_weights=task_weights, + v_kl=args.v_kl, + # threshold=args.threshold, TODO: Add in v2.1 + eps=args.eps, + alpha=args.alpha, + ) + else: + criterion = None + if args.metrics is not None: + metrics = [Factory.build(MetricRegistry[metric]) for metric in args.metrics] + else: + metrics = None + + predictor = Factory.build( + predictor_cls, + input_dim=mp_block.output_dim + d_xd, + n_tasks=n_tasks, + hidden_dim=args.ffn_hidden_dim, + n_layers=args.ffn_num_layers, + dropout=args.dropout, + activation=args.activation, + criterion=criterion, + task_weights=args.task_weights, + n_classes=args.multiclass_num_classes, + output_transform=output_transform, + # spectral_activation=args.spectral_activation, TODO: Add in v2.1 + ) + + if args.loss_function is None: + logger.info( + f"No loss function was specified! 
Using class default: {predictor_cls._T_default_criterion}" + ) + + return mpnn_cls( + mp_block, + agg, + predictor, + args.batch_norm, + metrics, + args.warmup_epochs, + args.init_lr, + args.max_lr, + args.final_lr, + X_d_transform=X_d_transform, + ) + + +def train_model( + args, train_loader, val_loader, test_loader, output_dir, output_transform, input_transforms +): + if args.checkpoint is not None: + model_paths = find_models(args.checkpoint) + if args.ensemble_size != len(model_paths): + logger.warning( + f"The number of models in ensemble for each splitting of data is set to {len(model_paths)}." + ) + args.ensemble_size = len(model_paths) + + for model_idx in range(args.ensemble_size): + model_output_dir = output_dir / f"model_{model_idx}" + model_output_dir.mkdir(exist_ok=True, parents=True) + + if args.pytorch_seed is None: + seed = torch.seed() + deterministic = False + else: + seed = args.pytorch_seed + model_idx + deterministic = True + + torch.manual_seed(seed) + + if args.checkpoint or args.model_frzn is not None: + mpnn_cls = ( + MulticomponentMPNN + if isinstance(train_loader.dataset, MulticomponentDataset) + else MPNN + ) + model_path = model_paths[model_idx] if args.checkpoint else args.model_frzn + model = mpnn_cls.load_from_file(model_path) + + if args.checkpoint: + model.apply( + lambda m: setattr(m, "p", args.dropout) + if isinstance(m, torch.nn.Dropout) + else None + ) + + # TODO: model_frzn is deprecated and then remove in v2.2 + if args.model_frzn or args.freeze_encoder: + model.message_passing.apply(lambda module: module.requires_grad_(False)) + model.message_passing.eval() + model.bn.apply(lambda module: module.requires_grad_(False)) + model.bn.eval() + for idx in range(args.frzn_ffn_layers): + model.predictor.ffn[idx].requires_grad_(False) + model.predictor.ffn[idx + 1].eval() + else: + model = build_model(args, train_loader.dataset, output_transform, input_transforms) + logger.info(model) + + try: + trainer_logger = TensorBoardLogger( + model_output_dir, "trainer_logs", default_hp_metric=False + ) + except ModuleNotFoundError as e: + logger.warning( + f"Unable to import TensorBoardLogger, reverting to CSVLogger (original error: {e})." 
+ ) + trainer_logger = CSVLogger(model_output_dir, "trainer_logs") + + if args.tracking_metric == "val_loss": + T_tracking_metric = model.criterion.__class__ + tracking_metric = args.tracking_metric + else: + T_tracking_metric = MetricRegistry[args.tracking_metric] + tracking_metric = "val/" + args.tracking_metric + + monitor_mode = "max" if T_tracking_metric.higher_is_better else "min" + logger.debug(f"Evaluation metric: '{T_tracking_metric.alias}', mode: '{monitor_mode}'") + + if args.remove_checkpoints: + temp_dir = TemporaryDirectory() + checkpoint_dir = Path(temp_dir.name) + else: + checkpoint_dir = model_output_dir + + checkpoint_filename = ( + f"best-epoch={{epoch}}-{tracking_metric.replace('/', '_')}=" + f"{{{tracking_metric}:.2f}}" + ) + checkpointing = ModelCheckpoint( + checkpoint_dir / "checkpoints", + checkpoint_filename, + tracking_metric, + mode=monitor_mode, + save_last=True, + auto_insert_metric_name=False, + ) + + if args.epochs != -1: + patience = args.patience if args.patience is not None else args.epochs + early_stopping = EarlyStopping(tracking_metric, patience=patience, mode=monitor_mode) + callbacks = [checkpointing, early_stopping] + else: + callbacks = [checkpointing] + + trainer = pl.Trainer( + logger=trainer_logger, + enable_progress_bar=True, + accelerator=args.accelerator, + devices=args.devices, + max_epochs=args.epochs, + callbacks=callbacks, + gradient_clip_val=args.grad_clip, + deterministic=deterministic, + ) + trainer.fit(model, train_loader, val_loader) + + if test_loader is not None: + if isinstance(trainer.strategy, DDPStrategy): + torch.distributed.destroy_process_group() + + best_ckpt_path = trainer.checkpoint_callback.best_model_path + trainer = pl.Trainer( + logger=trainer_logger, + enable_progress_bar=True, + accelerator=args.accelerator, + devices=1, + ) + model = model.load_from_checkpoint(best_ckpt_path) + predss = trainer.predict(model, dataloaders=test_loader) + else: + predss = trainer.predict(dataloaders=test_loader) + + preds = torch.concat(predss, 0) + if model.predictor.n_targets > 1: + preds = preds[..., 0] + preds = preds.numpy() + + evaluate_and_save_predictions( + preds, test_loader, model.metrics[:-1], model_output_dir, args + ) + + best_model_path = checkpointing.best_model_path + model = model.__class__.load_from_checkpoint(best_model_path) + p_model = model_output_dir / "best.pt" + save_model(p_model, model, args.target_columns) + logger.info(f"Best model saved to '{p_model}'") + + if args.remove_checkpoints: + temp_dir.cleanup() + + +def evaluate_and_save_predictions(preds, test_loader, metrics, model_output_dir, args): + if isinstance(test_loader.dataset, MulticomponentDataset): + test_dset = test_loader.dataset.datasets[0] + else: + test_dset = test_loader.dataset + targets = test_dset.Y + mask = torch.from_numpy(np.isfinite(targets)) + targets = np.nan_to_num(targets, nan=0.0) + weights = torch.ones(len(test_dset)) + lt_mask = torch.from_numpy(test_dset.lt_mask) if test_dset.lt_mask[0] is not None else None + gt_mask = torch.from_numpy(test_dset.gt_mask) if test_dset.gt_mask[0] is not None else None + + individual_scores = dict() + for metric in metrics: + individual_scores[metric.alias] = [] + for i, col in enumerate(args.target_columns): + if "multiclass" in args.task_type: + preds_slice = torch.from_numpy(preds[:, i : i + 1, :]) + targets_slice = torch.from_numpy(targets[:, i : i + 1]) + else: + preds_slice = torch.from_numpy(preds[:, i : i + 1]) + targets_slice = torch.from_numpy(targets[:, i : i + 1]) + preds_loss = 
metric( + preds_slice, + targets_slice, + mask[:, i : i + 1], + weights, + lt_mask[:, i] if lt_mask is not None else None, + gt_mask[:, i] if gt_mask is not None else None, + ) + individual_scores[metric.alias].append(preds_loss) + + logger.info("Test Set results:") + for metric in metrics: + avg_loss = sum(individual_scores[metric.alias]) / len(individual_scores[metric.alias]) + logger.info(f"test/{metric.alias}: {avg_loss}") + + if args.show_individual_scores: + logger.info("Entire Test Set individual results:") + for metric in metrics: + for i, col in enumerate(args.target_columns): + logger.info(f"test/{col}/{metric.alias}: {individual_scores[metric.alias][i]}") + + names = test_loader.dataset.names + if isinstance(test_loader.dataset, MulticomponentDataset): + namess = list(zip(*names)) + else: + namess = [names] + + columns = args.input_columns + args.target_columns + if "multiclass" in args.task_type: + columns = columns + [f"{col}_prob" for col in args.target_columns] + formatted_probability_strings = np.apply_along_axis( + lambda x: ",".join(map(str, x)), 2, preds + ) + predicted_class_labels = preds.argmax(axis=-1) + df_preds = pd.DataFrame( + list(zip(*namess, *predicted_class_labels.T, *formatted_probability_strings.T)), + columns=columns, + ) + else: + df_preds = pd.DataFrame(list(zip(*namess, *preds.T)), columns=columns) + df_preds.to_csv(model_output_dir / "test_predictions.csv", index=False) + + +def main(args): + format_kwargs = dict( + no_header_row=args.no_header_row, + smiles_cols=args.smiles_columns, + rxn_cols=args.reaction_columns, + target_cols=args.target_columns, + ignore_cols=args.ignore_columns, + splits_col=args.splits_column, + weight_col=args.weight_column, + bounded=args.loss_function is not None and "bounded" in args.loss_function, + ) + + featurization_kwargs = dict( + molecule_featurizers=args.molecule_featurizers, + keep_h=args.keep_h, + add_h=args.add_h, + ignore_chirality=args.ignore_chirality, + ) + + splits = build_splits(args, format_kwargs, featurization_kwargs) + + for replicate_idx, (train_data, val_data, test_data) in enumerate(zip(*splits)): + if args.num_replicates == 1: + output_dir = args.output_dir + else: + output_dir = args.output_dir / f"replicate_{replicate_idx}" + + output_dir.mkdir(exist_ok=True, parents=True) + + train_dset, val_dset, test_dset = build_datasets(args, train_data, val_data, test_data) + + if args.save_smiles_splits: + save_smiles_splits(args, output_dir, train_dset, val_dset, test_dset) + + if args.checkpoint or args.model_frzn is not None: + model_paths = find_models(args.checkpoint) + if len(model_paths) > 1: + logger.warning( + "Multiple checkpoint files were loaded, but only the scalers from " + f"{model_paths[0]} are used. It is assumed that all models provided have the " + "same data scalings, meaning they were trained on the same data." 
+ ) + model_path = model_paths[0] if args.checkpoint else args.model_frzn + load_and_use_pretrained_model_scalers(model_path, train_dset, val_dset) + input_transforms = (None, None, None) + output_transform = None + else: + input_transforms = normalize_inputs(train_dset, val_dset, args) + + if "regression" in args.task_type: + output_scaler = train_dset.normalize_targets() + val_dset.normalize_targets(output_scaler) + logger.info( + f"Train data: mean = {output_scaler.mean_} | std = {output_scaler.scale_}" + ) + output_transform = UnscaleTransform.from_standard_scaler(output_scaler) + else: + output_transform = None + + if not args.no_cache: + train_dset.cache = True + val_dset.cache = True + + train_loader = build_dataloader( + train_dset, + args.batch_size, + args.num_workers, + class_balance=args.class_balance, + seed=args.data_seed, + ) + if args.class_balance: + logger.debug( + f"With `--class-balance`, effective train size = {len(train_loader.sampler)}" + ) + val_loader = build_dataloader(val_dset, args.batch_size, args.num_workers, shuffle=False) + if test_dset is not None: + test_loader = build_dataloader( + test_dset, args.batch_size, args.num_workers, shuffle=False + ) + else: + test_loader = None + + train_model( + args, + train_loader, + val_loader, + test_loader, + output_dir, + output_transform, + input_transforms, + ) + + +if __name__ == "__main__": + # TODO: update this old code or remove it. + parser = ArgumentParser() + parser = TrainSubcommand.add_args(parser) + + logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True) + args = parser.parse_args() + TrainSubcommand.func(args) diff --git a/chemprop-updated/chemprop/cli/utils/__init__.py b/chemprop-updated/chemprop/cli/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fdd239a2a06724abe893c0913cd079addab26ea6 --- /dev/null +++ b/chemprop-updated/chemprop/cli/utils/__init__.py @@ -0,0 +1,30 @@ +from .actions import LookupAction +from .args import bounded +from .command import Subcommand +from .parsing import ( + build_data_from_files, + get_column_names, + make_datapoints, + make_dataset, + parse_indices, +) +from .utils import _pop_attr, _pop_attr_d, pop_attr + +__all__ = [ + "bounded", + "LookupAction", + "Subcommand", + "build_data_from_files", + "make_datapoints", + "make_dataset", + "get_column_names", + "parse_indices", + "actions", + "args", + "command", + "parsing", + "utils", + "pop_attr", + "_pop_attr", + "_pop_attr_d", +] diff --git a/chemprop-updated/chemprop/cli/utils/actions.py b/chemprop-updated/chemprop/cli/utils/actions.py new file mode 100644 index 0000000000000000000000000000000000000000..23e870f37b638499235ddccba0f72355efc3b7c7 --- /dev/null +++ b/chemprop-updated/chemprop/cli/utils/actions.py @@ -0,0 +1,19 @@ +from argparse import _StoreAction +from typing import Any, Mapping + + +def LookupAction(obj: Mapping[str, Any]): + class LookupAction_(_StoreAction): + def __init__(self, option_strings, dest, default=None, choices=None, **kwargs): + if default not in obj.keys() and default is not None: + raise ValueError( + f"Invalid value for arg 'default': '{default}'. 
" + f"Expected one of {tuple(obj.keys())}" + ) + + kwargs["choices"] = choices if choices is not None else obj.keys() + kwargs["default"] = default + + super().__init__(option_strings, dest, **kwargs) + + return LookupAction_ diff --git a/chemprop-updated/chemprop/cli/utils/args.py b/chemprop-updated/chemprop/cli/utils/args.py new file mode 100644 index 0000000000000000000000000000000000000000..5c6f29e3cd48a39cda6555f6a35a133412df2dd2 --- /dev/null +++ b/chemprop-updated/chemprop/cli/utils/args.py @@ -0,0 +1,34 @@ +import functools + +__all__ = ["bounded"] + + +def bounded(lo: float | None = None, hi: float | None = None): + if lo is None and hi is None: + raise ValueError("No bounds provided!") + + def decorator(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + x = f(*args, **kwargs) + + if (lo is not None and hi is not None) and not lo <= x <= hi: + raise ValueError(f"Parsed value outside of range [{lo}, {hi}]! got: {x}") + if hi is not None and x > hi: + raise ValueError(f"Parsed value below {hi}! got: {x}") + if lo is not None and x < lo: + raise ValueError(f"Parsed value above {lo}]! got: {x}") + + return x + + return wrapper + + return decorator + + +def uppercase(x: str): + return x.upper() + + +def lowercase(x: str): + return x.lower() diff --git a/chemprop-updated/chemprop/cli/utils/command.py b/chemprop-updated/chemprop/cli/utils/command.py new file mode 100644 index 0000000000000000000000000000000000000000..d9edd0d91855240dade06b5d67ae929339d155fa --- /dev/null +++ b/chemprop-updated/chemprop/cli/utils/command.py @@ -0,0 +1,24 @@ +from abc import ABC, abstractmethod +from argparse import ArgumentParser, Namespace, _SubParsersAction + + +class Subcommand(ABC): + COMMAND: str + HELP: str | None = None + + @classmethod + def add(cls, subparsers: _SubParsersAction, parents) -> ArgumentParser: + parser = subparsers.add_parser(cls.COMMAND, help=cls.HELP, parents=parents) + cls.add_args(parser).set_defaults(func=cls.func) + + return parser + + @classmethod + @abstractmethod + def add_args(cls, parser: ArgumentParser) -> ArgumentParser: + pass + + @classmethod + @abstractmethod + def func(cls, args: Namespace): + pass diff --git a/chemprop-updated/chemprop/cli/utils/parsing.py b/chemprop-updated/chemprop/cli/utils/parsing.py new file mode 100644 index 0000000000000000000000000000000000000000..064dd8614a2e0f717d3152ca5bba8b1a9d95d685 --- /dev/null +++ b/chemprop-updated/chemprop/cli/utils/parsing.py @@ -0,0 +1,457 @@ +import logging +from os import PathLike +from typing import Literal, Mapping, Sequence + +import numpy as np +import pandas as pd + +from chemprop.data.datapoints import MoleculeDatapoint, ReactionDatapoint +from chemprop.data.datasets import MoleculeDataset, ReactionDataset +from chemprop.featurizers.atom import get_multi_hot_atom_featurizer +from chemprop.featurizers.bond import MultiHotBondFeaturizer, RIGRBondFeaturizer +from chemprop.featurizers.molecule import MoleculeFeaturizerRegistry +from chemprop.featurizers.molgraph import ( + CondensedGraphOfReactionFeaturizer, + SimpleMoleculeMolGraphFeaturizer, +) +from chemprop.utils import make_mol + +logger = logging.getLogger(__name__) + + +def parse_csv( + path: PathLike, + smiles_cols: Sequence[str] | None, + rxn_cols: Sequence[str] | None, + target_cols: Sequence[str] | None, + ignore_cols: Sequence[str] | None, + splits_col: str | None, + weight_col: str | None, + bounded: bool = False, + no_header_row: bool = False, +): + df = pd.read_csv(path, header=None if no_header_row else "infer", index_col=False) + + 
if smiles_cols is not None and rxn_cols is not None: + smiss = df[smiles_cols].T.values.tolist() + rxnss = df[rxn_cols].T.values.tolist() + input_cols = [*smiles_cols, *rxn_cols] + elif smiles_cols is not None and rxn_cols is None: + smiss = df[smiles_cols].T.values.tolist() + rxnss = None + input_cols = smiles_cols + elif smiles_cols is None and rxn_cols is not None: + smiss = None + rxnss = df[rxn_cols].T.values.tolist() + input_cols = rxn_cols + else: + smiss = df.iloc[:, [0]].T.values.tolist() + rxnss = None + input_cols = [df.columns[0]] + + if target_cols is None: + target_cols = list( + column + for column in df.columns + if column + not in set( # if splits or weight is None, df.columns will never have None + input_cols + (ignore_cols or []) + [splits_col] + [weight_col] + ) + ) + + Y = df[target_cols] + weights = None if weight_col is None else df[weight_col].to_numpy(np.single) + + if bounded: + Y = Y.astype(str) + lt_mask = Y.applymap(lambda x: "<" in x).to_numpy() + gt_mask = Y.applymap(lambda x: ">" in x).to_numpy() + Y = Y.applymap(lambda x: x.strip("<").strip(">")).to_numpy(np.single) + else: + Y = Y.to_numpy(np.single) + lt_mask = None + gt_mask = None + + return smiss, rxnss, Y, weights, lt_mask, gt_mask + + +def get_column_names( + path: PathLike, + smiles_cols: Sequence[str] | None, + rxn_cols: Sequence[str] | None, + target_cols: Sequence[str] | None, + ignore_cols: Sequence[str] | None, + splits_col: str | None, + weight_col: str | None, + no_header_row: bool = False, +) -> tuple[list[str], list[str]]: + df_cols = pd.read_csv(path, index_col=False, nrows=0).columns.tolist() + + if no_header_row: + return ["SMILES"], ["pred_" + str(i) for i in range((len(df_cols) - 1))] + + input_cols = (smiles_cols or []) + (rxn_cols or []) + + if len(input_cols) == 0: + input_cols = [df_cols[0]] + + if target_cols is None: + target_cols = list( + column + for column in df_cols + if column + not in set( + input_cols + (ignore_cols or []) + ([splits_col] or []) + ([weight_col] or []) + ) + ) + + return input_cols, target_cols + + +def make_datapoints( + smiss: list[list[str]] | None, + rxnss: list[list[str]] | None, + Y: np.ndarray, + weights: np.ndarray | None, + lt_mask: np.ndarray | None, + gt_mask: np.ndarray | None, + X_d: np.ndarray | None, + V_fss: list[list[np.ndarray] | list[None]] | None, + E_fss: list[list[np.ndarray] | list[None]] | None, + V_dss: list[list[np.ndarray] | list[None]] | None, + molecule_featurizers: list[str] | None, + keep_h: bool, + add_h: bool, + ignore_chirality: bool, +) -> tuple[list[list[MoleculeDatapoint]], list[list[ReactionDatapoint]]]: + """Make the :class:`MoleculeDatapoint`s and :class:`ReactionDatapoint`s for a given + dataset. + + Parameters + ---------- + smiss : list[list[str]] | None + a list of ``j`` lists of ``n`` SMILES strings, where ``j`` is the number of molecules per + datapoint and ``n`` is the number of datapoints. If ``None``, the corresponding list of + :class:`MoleculeDatapoint`\s will be empty. + rxnss : list[list[str]] | None + a list of ``k`` lists of ``n`` reaction SMILES strings, where ``k`` is the number of + reactions per datapoint. If ``None``, the corresponding list of :class:`ReactionDatapoint`\s + will be empty. + Y : np.ndarray + the target values of shape ``n x m``, where ``m`` is the number of targets + weights : np.ndarray | None + the weights of the datapoints to use in the loss function of shape ``n x m``. If ``None``, + the weights all default to 1. 
+ lt_mask : np.ndarray | None + a boolean mask of shape ``n x m`` indicating whether the targets are less than inequality + targets. If ``None``, ``lt_mask`` for all datapoints will be ``None``. + gt_mask : np.ndarray | None + a boolean mask of shape ``n x m`` indicating whether the targets are greater than inequality + targets. If ``None``, ``gt_mask`` for all datapoints will be ``None``. + X_d : np.ndarray | None + the extra descriptors of shape ``n x p``, where ``p`` is the number of extra descriptors. If + ``None``, ``x_d`` for all datapoints will be ``None``. + V_fss : list[list[np.ndarray] | list[None]] | None + a list of ``j`` lists of ``n`` np.ndarrays each of shape ``v_jn x q_j``, where ``v_jn`` is + the number of atoms in the j-th molecule of the n-th datapoint and ``q_j`` is the number of + extra atom features used for the j-th molecules. Any of the ``j`` lists can be a list of + None values if the corresponding component does not use extra atom features. If ``None``, + ``V_f`` for all datapoints will be ``None``. + E_fss : list[list[np.ndarray] | list[None]] | None + a list of ``j`` lists of ``n`` np.ndarrays each of shape ``e_jn x r_j``, where ``e_jn`` is + the number of bonds in the j-th molecule of the n-th datapoint and ``r_j`` is the number of + extra bond features used for the j-th molecules. Any of the ``j`` lists can be a list of + None values if the corresponding component does not use extra bond features. If ``None``, + ``E_f`` for all datapoints will be ``None``. + V_dss : list[list[np.ndarray] | list[None]] | None + a list of ``j`` lists of ``n`` np.ndarrays each of shape ``v_jn x s_j``, where ``s_j`` is + the number of extra atom descriptors used for the j-th molecules. Any of the ``j`` lists can + be a list of None values if the corresponding component does not use extra atom features. If + ``None``, ``V_d`` for all datapoints will be ``None``. + molecule_featurizers : list[str] | None + a list of molecule featurizer names to generate additional molecule features to use as extra + descriptors. If there are multiple molecules per datapoint, the featurizers will be applied + to each molecule and concatenated. Note that a :code:`ReactionDatapoint` has two + RDKit :class:`~rdkit.Chem.Mol` objects, reactant(s) and product(s). Each + ``molecule_featurizer`` will be applied to both of these objects. + keep_h : bool + whether to keep hydrogen atoms + add_h : bool + whether to add hydrogen atoms + ignore_chirality : bool + whether to ignore chirality information + + Returns + ------- + list[list[MoleculeDatapoint]] + a list of ``j`` lists of ``n`` :class:`MoleculeDatapoint`\s + list[list[ReactionDatapoint]] + a list of ``k`` lists of ``n`` :class:`ReactionDatapoint`\s + .. note:: + either ``j`` or ``k`` may be 0, in which case the corresponding list will be empty. + + Raises + ------ + ValueError + if both ``smiss`` and ``rxnss`` are ``None``. + if ``smiss`` and ``rxnss`` are both given and have different lengths. + """ + if smiss is None and rxnss is None: + raise ValueError("args 'smiss' and 'rnxss' were both `None`!") + elif rxnss is None: + N = len(smiss[0]) + rxnss = [] + elif smiss is None: + N = len(rxnss[0]) + smiss = [] + elif len(smiss[0]) != len(rxnss[0]): + raise ValueError( + f"args 'smiss' and 'rxnss' must have same length! 
got {len(smiss[0])} and {len(rxnss[0])}" + ) + else: + N = len(smiss[0]) + + if len(smiss) > 0: + molss = [[make_mol(smi, keep_h, add_h, ignore_chirality) for smi in smis] for smis in smiss] + if len(rxnss) > 0: + rctss = [ + [ + make_mol( + f"{rct_smi}.{agt_smi}" if agt_smi else rct_smi, keep_h, add_h, ignore_chirality + ) + for rct_smi, agt_smi, _ in (rxn.split(">") for rxn in rxns) + ] + for rxns in rxnss + ] + pdtss = [ + [ + make_mol(pdt_smi, keep_h, add_h, ignore_chirality) + for _, _, pdt_smi in (rxn.split(">") for rxn in rxns) + ] + for rxns in rxnss + ] + + weights = np.ones(N, dtype=np.single) if weights is None else weights + gt_mask = [None] * N if gt_mask is None else gt_mask + lt_mask = [None] * N if lt_mask is None else lt_mask + + n_mols = len(smiss) if smiss else 0 + V_fss = [[None] * N] * n_mols if V_fss is None else V_fss + E_fss = [[None] * N] * n_mols if E_fss is None else E_fss + V_dss = [[None] * N] * n_mols if V_dss is None else V_dss + + if X_d is None and molecule_featurizers is None: + X_d = [None] * N + elif molecule_featurizers is None: + pass + else: + molecule_featurizers = [MoleculeFeaturizerRegistry[mf]() for mf in molecule_featurizers] + + if len(smiss) > 0: + mol_descriptors = np.hstack( + [ + np.vstack([np.hstack([mf(mol) for mf in molecule_featurizers]) for mol in mols]) + for mols in molss + ] + ) + if X_d is None: + X_d = mol_descriptors + else: + X_d = np.hstack([X_d, mol_descriptors]) + + if len(rxnss) > 0: + rct_pdt_descriptors = np.hstack( + [ + np.vstack( + [ + np.hstack( + [mf(mol) for mf in molecule_featurizers for mol in (rct, pdt)] + ) + for rct, pdt in zip(rcts, pdts) + ] + ) + for rcts, pdts in zip(rctss, pdtss) + ] + ) + if X_d is None: + X_d = rct_pdt_descriptors + else: + X_d = np.hstack([X_d, rct_pdt_descriptors]) + + mol_data = [ + [ + MoleculeDatapoint( + mol=molss[mol_idx][i], + name=smis[i], + y=Y[i], + weight=weights[i], + gt_mask=gt_mask[i], + lt_mask=lt_mask[i], + x_d=X_d[i], + x_phase=None, + V_f=V_fss[mol_idx][i], + E_f=E_fss[mol_idx][i], + V_d=V_dss[mol_idx][i], + ) + for i in range(N) + ] + for mol_idx, smis in enumerate(smiss) + ] + rxn_data = [ + [ + ReactionDatapoint( + rct=rctss[rxn_idx][i], + pdt=pdtss[rxn_idx][i], + name=rxns[i], + y=Y[i], + weight=weights[i], + gt_mask=gt_mask[i], + lt_mask=lt_mask[i], + x_d=X_d[i], + x_phase=None, + ) + for i in range(N) + ] + for rxn_idx, rxns in enumerate(rxnss) + ] + + return mol_data, rxn_data + + +def build_data_from_files( + p_data: PathLike, + no_header_row: bool, + smiles_cols: Sequence[str] | None, + rxn_cols: Sequence[str] | None, + target_cols: Sequence[str] | None, + ignore_cols: Sequence[str] | None, + splits_col: str | None, + weight_col: str | None, + bounded: bool, + p_descriptors: PathLike, + p_atom_feats: dict[int, PathLike], + p_bond_feats: dict[int, PathLike], + p_atom_descs: dict[int, PathLike], + **featurization_kwargs: Mapping, +) -> list[list[MoleculeDatapoint] | list[ReactionDatapoint]]: + smiss, rxnss, Y, weights, lt_mask, gt_mask = parse_csv( + p_data, + smiles_cols, + rxn_cols, + target_cols, + ignore_cols, + splits_col, + weight_col, + bounded, + no_header_row, + ) + n_molecules = len(smiss) if smiss is not None else 0 + n_datapoints = len(Y) + + X_ds = load_input_feats_and_descs(p_descriptors, None, None, feat_desc="X_d") + V_fss = load_input_feats_and_descs(p_atom_feats, n_molecules, n_datapoints, feat_desc="V_f") + E_fss = load_input_feats_and_descs(p_bond_feats, n_molecules, n_datapoints, feat_desc="E_f") + V_dss = 
load_input_feats_and_descs(p_atom_descs, n_molecules, n_datapoints, feat_desc="V_d") + + mol_data, rxn_data = make_datapoints( + smiss, + rxnss, + Y, + weights, + lt_mask, + gt_mask, + X_ds, + V_fss, + E_fss, + V_dss, + **featurization_kwargs, + ) + + return mol_data + rxn_data + + +def load_input_feats_and_descs( + paths: dict[int, PathLike] | PathLike, + n_molecules: int | None, + n_datapoints: int | None, + feat_desc: str, +): + if paths is None: + return None + + match feat_desc: + case "X_d": + path = paths + loaded_feature = np.load(path) + features = loaded_feature["arr_0"] + + case _: + for index in paths: + if index >= n_molecules: + raise ValueError( + f"For {n_molecules} molecules, atom/bond features/descriptors can only be " + f"specified for indices 0-{n_molecules - 1}! Got index {index}." + ) + + features = [] + for idx in range(n_molecules): + path = paths.get(idx, None) + + if path is not None: + loaded_feature = np.load(path) + loaded_feature = [ + loaded_feature[f"arr_{i}"] for i in range(len(loaded_feature)) + ] + else: + loaded_feature = [None] * n_datapoints + + features.append(loaded_feature) + return features + + +def make_dataset( + data: Sequence[MoleculeDatapoint] | Sequence[ReactionDatapoint], + reaction_mode: str, + multi_hot_atom_featurizer_mode: Literal["V1", "V2", "ORGANIC", "RIGR"] = "V2", +) -> MoleculeDataset | ReactionDataset: + atom_featurizer = get_multi_hot_atom_featurizer(multi_hot_atom_featurizer_mode) + match multi_hot_atom_featurizer_mode: + case "RIGR": + bond_featurizer = RIGRBondFeaturizer() + case "V1" | "V2" | "ORGANIC": + bond_featurizer = MultiHotBondFeaturizer() + case _: + raise TypeError( + f"Unsupported atom featurizer mode '{multi_hot_atom_featurizer_mode=}'!" + ) + + if isinstance(data[0], MoleculeDatapoint): + extra_atom_fdim = data[0].V_f.shape[1] if data[0].V_f is not None else 0 + extra_bond_fdim = data[0].E_f.shape[1] if data[0].E_f is not None else 0 + featurizer = SimpleMoleculeMolGraphFeaturizer( + atom_featurizer=atom_featurizer, + bond_featurizer=bond_featurizer, + extra_atom_fdim=extra_atom_fdim, + extra_bond_fdim=extra_bond_fdim, + ) + return MoleculeDataset(data, featurizer) + + featurizer = CondensedGraphOfReactionFeaturizer( + mode_=reaction_mode, atom_featurizer=atom_featurizer + ) + + return ReactionDataset(data, featurizer) + + +def parse_indices(idxs): + """Parses a string of indices into a list of integers. e.g. '0,1,2-4' -> [0, 1, 2, 3, 4]""" + if isinstance(idxs, str): + indices = [] + for idx in idxs.split(","): + if "-" in idx: + start, end = map(int, idx.split("-")) + indices.extend(range(start, end + 1)) + else: + indices.append(int(idx)) + return indices + return idxs diff --git a/chemprop-updated/chemprop/cli/utils/utils.py b/chemprop-updated/chemprop/cli/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8f63d224a36065f6a3332b7a3450a5ceb6a05568 --- /dev/null +++ b/chemprop-updated/chemprop/cli/utils/utils.py @@ -0,0 +1,31 @@ +from typing import Any + +__all__ = ["pop_attr"] + + +def pop_attr(o: object, attr: str, *args) -> Any | None: + """like ``pop()`` but for attribute maps""" + match len(args): + case 0: + return _pop_attr(o, attr) + case 1: + return _pop_attr_d(o, attr, args[0]) + case _: + raise TypeError(f"Expected at most 2 arguments! 
got: {len(args)}") + + +def _pop_attr(o: object, attr: str) -> Any: + val = getattr(o, attr) + delattr(o, attr) + + return val + + +def _pop_attr_d(o: object, attr: str, default: Any | None = None) -> Any | None: + try: + val = getattr(o, attr) + delattr(o, attr) + except AttributeError: + val = default + + return val diff --git a/chemprop-updated/chemprop/conf.py b/chemprop-updated/chemprop/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..f8e3681442d4d4cb553b38d698c4b102bd7c088d --- /dev/null +++ b/chemprop-updated/chemprop/conf.py @@ -0,0 +1,6 @@ +"""Global configuration variables for chemprop""" + +from chemprop.featurizers.molgraph.molecule import SimpleMoleculeMolGraphFeaturizer + +DEFAULT_ATOM_FDIM, DEFAULT_BOND_FDIM = SimpleMoleculeMolGraphFeaturizer().shape +DEFAULT_HIDDEN_DIM = 300 diff --git a/chemprop-updated/chemprop/data/__init__.py b/chemprop-updated/chemprop/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..843b2a94583f12bf5ca08ac6052f721d67bb2b37 --- /dev/null +++ b/chemprop-updated/chemprop/data/__init__.py @@ -0,0 +1,41 @@ +from .collate import ( + BatchMolGraph, + MulticomponentTrainingBatch, + TrainingBatch, + collate_batch, + collate_multicomponent, +) +from .dataloader import build_dataloader +from .datapoints import MoleculeDatapoint, ReactionDatapoint +from .datasets import ( + Datum, + MoleculeDataset, + MolGraphDataset, + MulticomponentDataset, + ReactionDataset, +) +from .molgraph import MolGraph +from .samplers import ClassBalanceSampler, SeededSampler +from .splitting import SplitType, make_split_indices, split_data_by_indices + +__all__ = [ + "BatchMolGraph", + "TrainingBatch", + "collate_batch", + "MulticomponentTrainingBatch", + "collate_multicomponent", + "build_dataloader", + "MoleculeDatapoint", + "ReactionDatapoint", + "MoleculeDataset", + "ReactionDataset", + "Datum", + "MulticomponentDataset", + "MolGraphDataset", + "MolGraph", + "ClassBalanceSampler", + "SeededSampler", + "SplitType", + "make_split_indices", + "split_data_by_indices", +] diff --git a/chemprop-updated/chemprop/data/__pycache__/__init__.cpython-37.pyc b/chemprop-updated/chemprop/data/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ba5b174fcc8ba84a8d88c7393bd31d834300826 Binary files /dev/null and b/chemprop-updated/chemprop/data/__pycache__/__init__.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/data/__pycache__/data.cpython-37.pyc b/chemprop-updated/chemprop/data/__pycache__/data.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be9b5658c2e6c9bd0459e266f1d96bfe53f7b58a Binary files /dev/null and b/chemprop-updated/chemprop/data/__pycache__/data.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/data/__pycache__/scaffold.cpython-37.pyc b/chemprop-updated/chemprop/data/__pycache__/scaffold.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dcd3f273cf63eb42e4ea8edd23d1016e7912beb Binary files /dev/null and b/chemprop-updated/chemprop/data/__pycache__/scaffold.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/data/__pycache__/scaler.cpython-37.pyc b/chemprop-updated/chemprop/data/__pycache__/scaler.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4178af6742c18488b79dfcec26a804fa465c392 Binary files /dev/null and b/chemprop-updated/chemprop/data/__pycache__/scaler.cpython-37.pyc differ diff --git 
a/chemprop-updated/chemprop/data/__pycache__/utils.cpython-37.pyc b/chemprop-updated/chemprop/data/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c70de06fd22b9a9e8d7894402c32121be965d7b5 Binary files /dev/null and b/chemprop-updated/chemprop/data/__pycache__/utils.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/data/collate.py b/chemprop-updated/chemprop/data/collate.py new file mode 100644 index 0000000000000000000000000000000000000000..1147b136a89286efdea12edafe7e17c6136e3a7c --- /dev/null +++ b/chemprop-updated/chemprop/data/collate.py @@ -0,0 +1,123 @@ +from dataclasses import InitVar, dataclass, field +from typing import Iterable, NamedTuple, Sequence + +import numpy as np +import torch +from torch import Tensor + +from chemprop.data.datasets import Datum +from chemprop.data.molgraph import MolGraph + + +@dataclass(repr=False, eq=False, slots=True) +class BatchMolGraph: + """A :class:`BatchMolGraph` represents a batch of individual :class:`MolGraph`\s. + + It has all the attributes of a ``MolGraph`` with the addition of the ``batch`` attribute. This + class is intended for use with data loading, so it uses :obj:`~torch.Tensor`\s to store data + """ + + mgs: InitVar[Sequence[MolGraph]] + """A list of individual :class:`MolGraph`\s to be batched together""" + V: Tensor = field(init=False) + """the atom feature matrix""" + E: Tensor = field(init=False) + """the bond feature matrix""" + edge_index: Tensor = field(init=False) + """an tensor of shape ``2 x E`` containing the edges of the graph in COO format""" + rev_edge_index: Tensor = field(init=False) + """A tensor of shape ``E`` that maps from an edge index to the index of the source of the + reverse edge in the ``edge_index`` attribute.""" + batch: Tensor = field(init=False) + """the index of the parent :class:`MolGraph` in the batched graph""" + names: list[str] = field(init=False) # Add SMILES strings for the batch + + __size: int = field(init=False) + + def __post_init__(self, mgs: Sequence[MolGraph]): + self.__size = len(mgs) + + Vs = [] + Es = [] + edge_indexes = [] + rev_edge_indexes = [] + batch_indexes = [] + self.names = [] + + num_nodes = 0 + num_edges = 0 + for i, mg in enumerate(mgs): + Vs.append(mg.V) + Es.append(mg.E) + edge_indexes.append(mg.edge_index + num_nodes) + rev_edge_indexes.append(mg.rev_edge_index + num_edges) + batch_indexes.append([i] * len(mg.V)) + self.names.append(mg.name) + + num_nodes += mg.V.shape[0] + num_edges += mg.edge_index.shape[1] + + self.V = torch.from_numpy(np.concatenate(Vs)).float() + self.E = torch.from_numpy(np.concatenate(Es)).float() + self.edge_index = torch.from_numpy(np.hstack(edge_indexes)).long() + self.rev_edge_index = torch.from_numpy(np.concatenate(rev_edge_indexes)).long() + self.batch = torch.tensor(np.concatenate(batch_indexes)).long() + + def __len__(self) -> int: + """the number of individual :class:`MolGraph`\s in this batch""" + return self.__size + + def to(self, device: str | torch.device): + self.V = self.V.to(device) + self.E = self.E.to(device) + self.edge_index = self.edge_index.to(device) + self.rev_edge_index = self.rev_edge_index.to(device) + self.batch = self.batch.to(device) + + +class TrainingBatch(NamedTuple): + bmg: BatchMolGraph + V_d: Tensor | None + X_d: Tensor | None + Y: Tensor | None + w: Tensor + lt_mask: Tensor | None + gt_mask: Tensor | None + + +def collate_batch(batch: Iterable[Datum]) -> TrainingBatch: + mgs, V_ds, x_ds, ys, weights, lt_masks, gt_masks = zip(*batch) + + return 
TrainingBatch( + BatchMolGraph(mgs), + None if V_ds[0] is None else torch.from_numpy(np.concatenate(V_ds)).float(), + None if x_ds[0] is None else torch.from_numpy(np.array(x_ds)).float(), + None if ys[0] is None else torch.from_numpy(np.array(ys)).float(), + torch.tensor(weights, dtype=torch.float).unsqueeze(1), + None if lt_masks[0] is None else torch.from_numpy(np.array(lt_masks)), + None if gt_masks[0] is None else torch.from_numpy(np.array(gt_masks)), + ) + + +class MulticomponentTrainingBatch(NamedTuple): + bmgs: list[BatchMolGraph] + V_ds: list[Tensor | None] + X_d: Tensor | None + Y: Tensor | None + w: Tensor + lt_mask: Tensor | None + gt_mask: Tensor | None + + +def collate_multicomponent(batches: Iterable[Iterable[Datum]]) -> MulticomponentTrainingBatch: + tbs = [collate_batch(batch) for batch in zip(*batches)] + + return MulticomponentTrainingBatch( + [tb.bmg for tb in tbs], + [tb.V_d for tb in tbs], + tbs[0].X_d, + tbs[0].Y, + tbs[0].w, + tbs[0].lt_mask, + tbs[0].gt_mask, + ) diff --git a/chemprop-updated/chemprop/data/dataloader.py b/chemprop-updated/chemprop/data/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..4fc2b2ddee50c70794cb5b60403eecfa5241049f --- /dev/null +++ b/chemprop-updated/chemprop/data/dataloader.py @@ -0,0 +1,71 @@ +import logging + +from torch.utils.data import DataLoader + +from chemprop.data.collate import collate_batch, collate_multicomponent +from chemprop.data.datasets import MoleculeDataset, MulticomponentDataset, ReactionDataset +from chemprop.data.samplers import ClassBalanceSampler, SeededSampler + +logger = logging.getLogger(__name__) + + +def build_dataloader( + dataset: MoleculeDataset | ReactionDataset | MulticomponentDataset, + batch_size: int = 64, + num_workers: int = 0, + class_balance: bool = False, + seed: int | None = None, + shuffle: bool = True, + **kwargs, +): + """Return a :obj:`~torch.utils.data.DataLoader` for :class:`MolGraphDataset`\s + + Parameters + ---------- + dataset : MoleculeDataset | ReactionDataset | MulticomponentDataset + The dataset containing the molecules or reactions to load. + batch_size : int, default=64 + the batch size to load. + num_workers : int, default=0 + the number of workers used to build batches. + class_balance : bool, default=False + Whether to perform class balancing (i.e., use an equal number of positive and negative + molecules). Class balance is only available for single task classification datasets. Set + shuffle to True in order to get a random subset of the larger class. + seed : int | None, default=None + the random seed to use for shuffling (only used when `shuffle` is `True`). + shuffle : bool, default=True + whether to shuffle the data during sampling.
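+
+ Examples
+ --------
+ A minimal usage sketch (illustrative only; ``dataset`` is assumed to be an
+ already-constructed :class:`MoleculeDataset`)::
+
+ loader = build_dataloader(dataset, batch_size=32, seed=0)
+ for bmg, V_d, X_d, Y, w, lt_mask, gt_mask in loader:
+ ...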
+ """ + + if class_balance: + sampler = ClassBalanceSampler(dataset.Y, seed, shuffle) + elif shuffle and seed is not None: + sampler = SeededSampler(len(dataset), seed) + else: + sampler = None + + if isinstance(dataset, MulticomponentDataset): + collate_fn = collate_multicomponent + else: + collate_fn = collate_batch + + if len(dataset) % batch_size == 1: + logger.warning( + f"Dropping last batch of size 1 to avoid issues with batch normalization \ +(dataset size = {len(dataset)}, batch_size = {batch_size})" + ) + drop_last = True + else: + drop_last = False + + return DataLoader( + dataset, + batch_size, + sampler is None and shuffle, + sampler, + num_workers=num_workers, + collate_fn=collate_fn, + drop_last=drop_last, + **kwargs, + ) diff --git a/chemprop-updated/chemprop/data/datapoints.py b/chemprop-updated/chemprop/data/datapoints.py new file mode 100644 index 0000000000000000000000000000000000000000..8c94a9a78c74b1fbb946706248839ff9d1c00e26 --- /dev/null +++ b/chemprop-updated/chemprop/data/datapoints.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import numpy as np +from rdkit.Chem import AllChem as Chem + +from chemprop.featurizers import Featurizer +from chemprop.utils import make_mol + +MoleculeFeaturizer = Featurizer[Chem.Mol, np.ndarray] + + +@dataclass(slots=True) +class _DatapointMixin: + """A mixin class for both molecule- and reaction- and multicomponent-type data""" + + y: np.ndarray | None = None + """the targets for the molecule with unknown targets indicated by `nan`s""" + weight: float = 1.0 + """the weight of this datapoint for the loss calculation.""" + gt_mask: np.ndarray | None = None + """Indicates whether the targets are an inequality regression target of the form `x`""" + x_d: np.ndarray | None = None + """A vector of length ``d_f`` containing additional features (e.g., Morgan fingerprint) that + will be concatenated to the global representation *after* aggregation""" + x_phase: list[float] = None + """A one-hot vector indicating the phase of the data, as used in spectra data.""" + name: str | None = None + """A string identifier for the datapoint.""" + + def __post_init__(self): + NAN_TOKEN = 0 + if self.x_d is not None: + self.x_d[np.isnan(self.x_d)] = NAN_TOKEN + + @property + def t(self) -> int | None: + return len(self.y) if self.y is not None else None + + +@dataclass +class _MoleculeDatapointMixin: + mol: Chem.Mol + """the molecule associated with this datapoint""" + + @classmethod + def from_smi( + cls, + smi: str, + *args, + keep_h: bool = False, + add_h: bool = False, + ignore_chirality: bool = False, + **kwargs, + ) -> _MoleculeDatapointMixin: + mol = make_mol(smi, keep_h, add_h, ignore_chirality) + + kwargs["name"] = smi if "name" not in kwargs else kwargs["name"] + + return cls(mol, *args, **kwargs) + + +@dataclass +class MoleculeDatapoint(_DatapointMixin, _MoleculeDatapointMixin): + """A :class:`MoleculeDatapoint` contains a single molecule and its associated features and targets.""" + + V_f: np.ndarray | None = None + """a numpy array of shape ``V x d_vf``, where ``V`` is the number of atoms in the molecule, and + ``d_vf`` is the number of additional features that will be concatenated to atom-level features + *before* message passing""" + E_f: np.ndarray | None = None + """A numpy array of shape ``E x d_ef``, where ``E`` is the number of bonds in the molecule, and + ``d_ef`` is the number of additional features containing additional features that will be + concatenated to bond-level features 
*before* message passing""" + V_d: np.ndarray | None = None + """A numpy array of shape ``V x d_vd``, where ``V`` is the number of atoms in the molecule, and + ``d_vd`` is the number of additional descriptors that will be concatenated to atom-level + descriptors *after* message passing""" + + def __post_init__(self): + NAN_TOKEN = 0 + if self.V_f is not None: + self.V_f[np.isnan(self.V_f)] = NAN_TOKEN + if self.E_f is not None: + self.E_f[np.isnan(self.E_f)] = NAN_TOKEN + if self.V_d is not None: + self.V_d[np.isnan(self.V_d)] = NAN_TOKEN + + super().__post_init__() + + def __len__(self) -> int: + return 1 + + +@dataclass +class _ReactionDatapointMixin: + rct: Chem.Mol + """the reactant associated with this datapoint""" + pdt: Chem.Mol + """the product associated with this datapoint""" + + @classmethod + def from_smi( + cls, + rxn_or_smis: str | tuple[str, str], + *args, + keep_h: bool = False, + add_h: bool = False, + ignore_chirality: bool = False, + **kwargs, + ) -> _ReactionDatapointMixin: + match rxn_or_smis: + case str(): + rct_smi, agt_smi, pdt_smi = rxn_or_smis.split(">") + rct_smi = f"{rct_smi}.{agt_smi}" if agt_smi else rct_smi + name = rxn_or_smis + case tuple(): + rct_smi, pdt_smi = rxn_or_smis + name = ">>".join(rxn_or_smis) + case _: + raise TypeError( + "Must provide either a reaction SMARTS string or a tuple of reactant and" + " a product SMILES strings!" + ) + + rct = make_mol(rct_smi, keep_h, add_h, ignore_chirality) + pdt = make_mol(pdt_smi, keep_h, add_h, ignore_chirality) + + kwargs["name"] = name if "name" not in kwargs else kwargs["name"] + + return cls(rct, pdt, *args, **kwargs) + + +@dataclass +class ReactionDatapoint(_DatapointMixin, _ReactionDatapointMixin): + """A :class:`ReactionDatapoint` contains a single reaction and its associated features and targets.""" + + def __post_init__(self): + if self.rct is None: + raise ValueError("Reactant cannot be `None`!") + if self.pdt is None: + raise ValueError("Product cannot be `None`!") + + return super().__post_init__() + + def __len__(self) -> int: + return 2 diff --git a/chemprop-updated/chemprop/data/datasets.py b/chemprop-updated/chemprop/data/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..1bebad817eeb75c07f9cb01ce32b960a326a3e1a --- /dev/null +++ b/chemprop-updated/chemprop/data/datasets.py @@ -0,0 +1,475 @@ +from dataclasses import dataclass, field +from functools import cached_property +from typing import NamedTuple, TypeAlias + +import numpy as np +from numpy.typing import ArrayLike +from rdkit import Chem +from rdkit.Chem import Mol +from sklearn.preprocessing import StandardScaler +from torch.utils.data import Dataset + +from chemprop.data.datapoints import MoleculeDatapoint, ReactionDatapoint +from chemprop.data.molgraph import MolGraph +from chemprop.featurizers.base import Featurizer +from chemprop.featurizers.molgraph import CGRFeaturizer, SimpleMoleculeMolGraphFeaturizer +from chemprop.featurizers.molgraph.cache import MolGraphCache, MolGraphCacheOnTheFly +from chemprop.types import Rxn + + +class Datum(NamedTuple): + """a singular training data point""" + + mg: MolGraph + V_d: np.ndarray | None + x_d: np.ndarray | None + y: np.ndarray | None + weight: float + lt_mask: np.ndarray | None + gt_mask: np.ndarray | None + + +MolGraphDataset: TypeAlias = Dataset[Datum] + + +class _MolGraphDatasetMixin: + def __len__(self) -> int: + return len(self.data) + + @cached_property + def _Y(self) -> np.ndarray: + """the raw targets of the dataset""" + return np.array([d.y for d in 
self.data], float) + + @property + def Y(self) -> np.ndarray: + """the (scaled) targets of the dataset""" + return self.__Y + + @Y.setter + def Y(self, Y: ArrayLike): + self._validate_attribute(Y, "targets") + + self.__Y = np.array(Y, float) + + @cached_property + def _X_d(self) -> np.ndarray: + """the raw extra descriptors of the dataset""" + return np.array([d.x_d for d in self.data]) + + @property + def X_d(self) -> np.ndarray: + """the (scaled) extra descriptors of the dataset""" + return self.__X_d + + @X_d.setter + def X_d(self, X_d: ArrayLike): + self._validate_attribute(X_d, "extra descriptors") + + self.__X_d = np.array(X_d) + + @property + def weights(self) -> np.ndarray: + return np.array([d.weight for d in self.data]) + + @property + def gt_mask(self) -> np.ndarray: + return np.array([d.gt_mask for d in self.data]) + + @property + def lt_mask(self) -> np.ndarray: + return np.array([d.lt_mask for d in self.data]) + + @property + def t(self) -> int | None: + return self.data[0].t if len(self.data) > 0 else None + + @property + def d_xd(self) -> int: + """the extra molecule descriptor dimension, if any""" + return 0 if self.X_d[0] is None else self.X_d.shape[1] + + @property + def names(self) -> list[str]: + return [d.name for d in self.data] + + def normalize_targets(self, scaler: StandardScaler | None = None) -> StandardScaler: + """Normalizes the targets of this dataset using a :obj:`StandardScaler` + + The :obj:`StandardScaler` subtracts the mean and divides by the standard deviation for + each task independently. NOTE: This should only be used for regression datasets. + + Returns + ------- + StandardScaler + a scaler fit to the targets. + """ + + if scaler is None: + scaler = StandardScaler().fit(self._Y) + + self.Y = scaler.transform(self._Y) + + return scaler + + def normalize_inputs( + self, key: str = "X_d", scaler: StandardScaler | None = None + ) -> StandardScaler: + VALID_KEYS = {"X_d"} + if key not in VALID_KEYS: + raise ValueError(f"Invalid feature key! got: {key}. expected one of: {VALID_KEYS}") + + X = self.X_d if self.X_d[0] is not None else None + + if X is None: + return scaler + + if scaler is None: + scaler = StandardScaler().fit(X) + + self.X_d = scaler.transform(X) + + return scaler + + def reset(self): + """Reset the atom and bond features; atom and extra descriptors; and targets of each + datapoint to their initial, unnormalized values.""" + self.__Y = self._Y + self.__X_d = self._X_d + + def _validate_attribute(self, X: np.ndarray, label: str): + if not len(self.data) == len(X): + raise ValueError( + f"number of molecules ({len(self.data)}) and {label} ({len(X)}) " + "must have same length!" + ) + + +@dataclass +class MoleculeDataset(_MolGraphDatasetMixin, MolGraphDataset): + """A :class:`MoleculeDataset` composed of :class:`MoleculeDatapoint`\s + + A :class:`MoleculeDataset` produces featurized data for input to a + :class:`MPNN` model. Typically, data featurization is performed on-the-fly + and parallelized across multiple workers via the :class:`~torch.utils.data + DataLoader` class. However, for small datasets, it may be more efficient to + featurize the data in advance and cache the results. This can be done by + setting ``MoleculeDataset.cache=True``. 
+ + Parameters + ---------- + data : Iterable[MoleculeDatapoint] + the data from which to create a dataset + featurizer : MoleculeFeaturizer + the featurizer with which to generate MolGraphs of the molecules + """ + + data: list[MoleculeDatapoint] + featurizer: Featurizer[Mol, MolGraph] = field(default_factory=SimpleMoleculeMolGraphFeaturizer) + + def __post_init__(self): + if self.data is None: + raise ValueError("Data cannot be None!") + + self.reset() + self.cache = False + + def __getitem__(self, idx: int) -> Datum: + d = self.data[idx] + mg = self.mg_cache[idx] + + # Assign the SMILES string to the MolGraph + mg_with_name = MolGraph( + V=mg.V, + E=mg.E, + edge_index=mg.edge_index, + rev_edge_index=mg.rev_edge_index, + name=d.name # Assign the SMILES string + ) + + return Datum( + mg=mg_with_name, # Use the updated MolGraph + V_d=self.V_ds[idx], + x_d=self.X_d[idx], + y=self.Y[idx], + weight=d.weight, + lt_mask=d.lt_mask, + gt_mask=d.gt_mask, + ) + @property + def cache(self) -> bool: + return self.__cache + + @cache.setter + def cache(self, cache: bool = False): + self.__cache = cache + self._init_cache() + + def _init_cache(self): + """initialize the cache""" + self.mg_cache = (MolGraphCache if self.cache else MolGraphCacheOnTheFly)( + self.mols, self.V_fs, self.E_fs, self.featurizer + ) + + @property + def smiles(self) -> list[str]: + """the SMILES strings associated with the dataset""" + return [Chem.MolToSmiles(d.mol) for d in self.data] + + @property + def mols(self) -> list[Chem.Mol]: + """the molecules associated with the dataset""" + return [d.mol for d in self.data] + + @property + def _V_fs(self) -> list[np.ndarray]: + """the raw atom features of the dataset""" + return [d.V_f for d in self.data] + + @property + def V_fs(self) -> list[np.ndarray]: + """the (scaled) atom descriptors of the dataset""" + return self.__V_fs + + @V_fs.setter + def V_fs(self, V_fs: list[np.ndarray]): + """the (scaled) atom features of the dataset""" + self._validate_attribute(V_fs, "atom features") + + self.__V_fs = V_fs + self._init_cache() + + @property + def _E_fs(self) -> list[np.ndarray]: + """the raw bond features of the dataset""" + return [d.E_f for d in self.data] + + @property + def E_fs(self) -> list[np.ndarray]: + """the (scaled) bond features of the dataset""" + return self.__E_fs + + @E_fs.setter + def E_fs(self, E_fs: list[np.ndarray]): + self._validate_attribute(E_fs, "bond features") + + self.__E_fs = E_fs + self._init_cache() + + @property + def _V_ds(self) -> list[np.ndarray]: + """the raw atom descriptors of the dataset""" + return [d.V_d for d in self.data] + + @property + def V_ds(self) -> list[np.ndarray]: + """the (scaled) atom descriptors of the dataset""" + return self.__V_ds + + @V_ds.setter + def V_ds(self, V_ds: list[np.ndarray]): + self._validate_attribute(V_ds, "atom descriptors") + + self.__V_ds = V_ds + + @property + def d_vf(self) -> int: + """the extra atom feature dimension, if any""" + return 0 if self.V_fs[0] is None else self.V_fs[0].shape[1] + + @property + def d_ef(self) -> int: + """the extra bond feature dimension, if any""" + return 0 if self.E_fs[0] is None else self.E_fs[0].shape[1] + + @property + def d_vd(self) -> int: + """the extra atom descriptor dimension, if any""" + return 0 if self.V_ds[0] is None else self.V_ds[0].shape[1] + + def normalize_inputs( + self, key: str = "X_d", scaler: StandardScaler | None = None + ) -> StandardScaler: + VALID_KEYS = {"X_d", "V_f", "E_f", "V_d"} + + match key: + case "X_d": + X = None if self.d_xd == 0 else 
self.X_d + case "V_f": + X = None if self.d_vf == 0 else np.concatenate(self.V_fs, axis=0) + case "E_f": + X = None if self.d_ef == 0 else np.concatenate(self.E_fs, axis=0) + case "V_d": + X = None if self.d_vd == 0 else np.concatenate(self.V_ds, axis=0) + case _: + raise ValueError(f"Invalid feature key! got: {key}. expected one of: {VALID_KEYS}") + + if X is None: + return scaler + + if scaler is None: + scaler = StandardScaler().fit(X) + + match key: + case "X_d": + self.X_d = scaler.transform(X) + case "V_f": + self.V_fs = [scaler.transform(V_f) if V_f.size > 0 else V_f for V_f in self.V_fs] + case "E_f": + self.E_fs = [scaler.transform(E_f) if E_f.size > 0 else E_f for E_f in self.E_fs] + case "V_d": + self.V_ds = [scaler.transform(V_d) if V_d.size > 0 else V_d for V_d in self.V_ds] + case _: + raise RuntimeError("unreachable code reached!") + + return scaler + + def reset(self): + """Reset the atom and bond features; atom and extra descriptors; and targets of each + datapoint to their initial, unnormalized values.""" + super().reset() + self.__V_fs = self._V_fs + self.__E_fs = self._E_fs + self.__V_ds = self._V_ds + + +@dataclass +class ReactionDataset(_MolGraphDatasetMixin, MolGraphDataset): + """A :class:`ReactionDataset` composed of :class:`ReactionDatapoint`\s + + .. note:: + The featurized data provided by this class may be cached, similar to a + :class:`MoleculeDataset`. To enable the cache, set + ``ReactionDataset.cache=True``. + """ + + data: list[ReactionDatapoint] + """the dataset from which to load""" + featurizer: Featurizer[Rxn, MolGraph] = field(default_factory=CGRFeaturizer) + """the featurizer with which to generate MolGraphs of the input""" + + def __post_init__(self): + if self.data is None: + raise ValueError("Data cannot be None!") + + self.reset() + self.cache = False + + @property + def cache(self) -> bool: + return self.__cache + + @cache.setter + def cache(self, cache: bool = False): + self.__cache = cache + self.mg_cache = (MolGraphCache if cache else MolGraphCacheOnTheFly)( + self.mols, [None] * len(self), [None] * len(self), self.featurizer + ) + + def __getitem__(self, idx: int) -> Datum: + d = self.data[idx] + mg = self.mg_cache[idx] + + return Datum(mg, None, self.X_d[idx], self.Y[idx], d.weight, d.lt_mask, d.gt_mask) + + @property + def smiles(self) -> list[tuple]: + return [(Chem.MolToSmiles(d.rct), Chem.MolToSmiles(d.pdt)) for d in self.data] + + @property + def mols(self) -> list[Rxn]: + return [(d.rct, d.pdt) for d in self.data] + + @property + def d_vf(self) -> int: + return 0 + + @property + def d_ef(self) -> int: + return 0 + + @property + def d_vd(self) -> int: + return 0 + + +@dataclass(repr=False, eq=False) +class MulticomponentDataset(_MolGraphDatasetMixin, Dataset): + """A :class:`MulticomponentDataset` is a :class:`Dataset` composed of parallel + :class:`MoleculeDataset`\s and :class:`ReactionDataset`\s""" + + datasets: list[MoleculeDataset | ReactionDataset] + """the parallel datasets""" + + def __post_init__(self): + sizes = [len(dset) for dset in self.datasets] + if not all(sizes[0] == size for size in sizes[1:]): + raise ValueError(f"Datasets must all have the same length!
got: {sizes}") + + def __len__(self) -> int: + return len(self.datasets[0]) + + @property + def n_components(self) -> int: + return len(self.datasets) + + def __getitem__(self, idx: int) -> list[Datum]: + return [dset[idx] for dset in self.datasets] + + @property + def smiles(self) -> list[list[str]]: + return list(zip(*[dset.smiles for dset in self.datasets])) + + @property + def names(self) -> list[list[str]]: + return list(zip(*[dset.names for dset in self.datasets])) + + @property + def mols(self) -> list[list[Chem.Mol]]: + return list(zip(*[dset.mols for dset in self.datasets])) + + def normalize_targets(self, scaler: StandardScaler | None = None) -> StandardScaler: + return self.datasets[0].normalize_targets(scaler) + + def normalize_inputs( + self, key: str = "X_d", scaler: list[StandardScaler] | None = None + ) -> list[StandardScaler]: + RXN_VALID_KEYS = {"X_d"} + match scaler: + case None: + return [ + dset.normalize_inputs(key) + if isinstance(dset, MoleculeDataset) or key in RXN_VALID_KEYS + else None + for dset in self.datasets + ] + case _: + assert len(scaler) == len( + self.datasets + ), "Number of scalers must match number of datasets!" + + return [ + dset.normalize_inputs(key, s) + if isinstance(dset, MoleculeDataset) or key in RXN_VALID_KEYS + else None + for dset, s in zip(self.datasets, scaler) + ] + + def reset(self): + return [dset.reset() for dset in self.datasets] + + @property + def d_xd(self) -> list[int]: + return self.datasets[0].d_xd + + @property + def d_vf(self) -> list[int]: + return sum(dset.d_vf for dset in self.datasets) + + @property + def d_ef(self) -> list[int]: + return sum(dset.d_ef for dset in self.datasets) + + @property + def d_vd(self) -> list[int]: + return sum(dset.d_vd for dset in self.datasets) diff --git a/chemprop-updated/chemprop/data/molgraph.py b/chemprop-updated/chemprop/data/molgraph.py new file mode 100644 index 0000000000000000000000000000000000000000..af7025ae72f3a002566ef801b0b8b41b3add8fb8 --- /dev/null +++ b/chemprop-updated/chemprop/data/molgraph.py @@ -0,0 +1,17 @@ +from typing import NamedTuple + +import numpy as np + + +class MolGraph(NamedTuple): + """A :class:`MolGraph` represents the graph featurization of a molecule.""" + + V: np.ndarray + """an array of shape ``V x d_v`` containing the atom features of the molecule""" + E: np.ndarray + """an array of shape ``E x d_e`` containing the bond features of the molecule""" + edge_index: np.ndarray + """an array of shape ``2 x E`` containing the edges of the graph in COO format""" + rev_edge_index: np.ndarray + """A array of shape ``E`` that maps from an edge index to the index of the source of the reverse edge in :attr:`edge_index` attribute.""" + name: str | None = None # Add SMILES string as an optional attribute \ No newline at end of file diff --git a/chemprop-updated/chemprop/data/samplers.py b/chemprop-updated/chemprop/data/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..8a24c9769ce73fa7c6a853f25899d6a95bc212cb --- /dev/null +++ b/chemprop-updated/chemprop/data/samplers.py @@ -0,0 +1,66 @@ +from itertools import chain +from typing import Iterator, Optional + +import numpy as np +from torch.utils.data import Sampler + + +class SeededSampler(Sampler): + """A :class`SeededSampler` is a class for iterating through a dataset in a randomly seeded + fashion""" + + def __init__(self, N: int, seed: int): + if seed is None: + raise ValueError("arg 'seed' was `None`! 
A SeededSampler must be seeded!") + + self.idxs = np.arange(N) + self.rg = np.random.default_rng(seed) + + def __iter__(self) -> Iterator[int]: + """an iterator over indices to sample.""" + self.rg.shuffle(self.idxs) + + return iter(self.idxs) + + def __len__(self) -> int: + """the number of indices that will be sampled.""" + return len(self.idxs) + + +class ClassBalanceSampler(Sampler): + """A :class:`ClassBalanceSampler` samples data from a :class:`MolGraphDataset` such that + positive and negative classes are equally sampled + + Parameters + ---------- + dataset : MolGraphDataset + the dataset from which to sample + seed : int + the random seed to use for shuffling (only used when `shuffle` is `True`) + shuffle : bool, default=False + whether to shuffle the data during sampling + """ + + def __init__(self, Y: np.ndarray, seed: Optional[int] = None, shuffle: bool = False): + self.shuffle = shuffle + self.rg = np.random.default_rng(seed) + + idxs = np.arange(len(Y)) + actives = Y.any(1) + + self.pos_idxs = idxs[actives] + self.neg_idxs = idxs[~actives] + + self.length = 2 * min(len(self.pos_idxs), len(self.neg_idxs)) + + def __iter__(self) -> Iterator[int]: + """an iterator over indices to sample.""" + if self.shuffle: + self.rg.shuffle(self.pos_idxs) + self.rg.shuffle(self.neg_idxs) + + return chain(*zip(self.pos_idxs, self.neg_idxs)) + + def __len__(self) -> int: + """the number of indices that will be sampled.""" + return self.length diff --git a/chemprop-updated/chemprop/data/splitting.py b/chemprop-updated/chemprop/data/splitting.py new file mode 100644 index 0000000000000000000000000000000000000000..f4bb1b6f91667634bb21ad9460d9ee6e87286df3 --- /dev/null +++ b/chemprop-updated/chemprop/data/splitting.py @@ -0,0 +1,225 @@ +from collections.abc import Iterable, Sequence +import copy +from enum import auto +import logging + +from astartes import train_test_split, train_val_test_split +from astartes.molecules import train_test_split_molecules, train_val_test_split_molecules +import numpy as np +from rdkit import Chem + +from chemprop.data.datapoints import MoleculeDatapoint, ReactionDatapoint +from chemprop.utils.utils import EnumMapping + +logger = logging.getLogger(__name__) + +Datapoints = Sequence[MoleculeDatapoint] | Sequence[ReactionDatapoint] +MulticomponentDatapoints = Sequence[Datapoints] + + +class SplitType(EnumMapping): + SCAFFOLD_BALANCED = auto() + RANDOM_WITH_REPEATED_SMILES = auto() + RANDOM = auto() + KENNARD_STONE = auto() + KMEANS = auto() + + +def make_split_indices( + mols: Sequence[Chem.Mol], + split: SplitType | str = "random", + sizes: tuple[float, float, float] = (0.8, 0.1, 0.1), + seed: int = 0, + num_replicates: int = 1, + num_folds: None = None, +) -> tuple[list[list[int]], ...]: + """Splits data into training, validation, and test splits. + + Parameters + ---------- + mols : Sequence[Chem.Mol] + Sequence of RDKit molecules to use for structure based splitting + split : SplitType | str, optional + Split type, one of ~chemprop.data.utils.SplitType, by default "random" + sizes : tuple[float, float, float], optional + 3-tuple with the proportions of data in the train, validation, and test sets, by default + (0.8, 0.1, 0.1). Set the middle value to 0 for a two way split. + seed : int, optional + The random seed passed to astartes, by default 0 + num_replicates : int, optional + Number of replicates, by default 1 + num_folds : None, optional + This argument was removed in v2.1 - use `num_replicates` instead. + + Returns + ------- + tuple[list[list[int]], ...] 
+ 2- or 3-member tuple containing num_replicates length lists of training, validation, and testing indexes. + + .. important:: + Validation may or may not be present + + Raises + ------ + ValueError + Requested split sizes tuple not of length 3 + ValueError + Unsupported split method requested + """ + if num_folds is not None: + raise RuntimeError("This argument was removed in v2.1 - use `num_replicates` instead.") + if num_replicates == 1: + logger.warning( + "The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)" + ) + if (num_splits := len(sizes)) != 3: + raise ValueError( + f"Specify sizes for train, validation, and test (got {num_splits} values)." + ) + # typically include a validation set + include_val = True + split_fun = train_val_test_split + mol_split_fun = train_val_test_split_molecules + # default sampling arguments for astartes sampler + astartes_kwargs = dict( + train_size=sizes[0], test_size=sizes[2], return_indices=True, random_state=seed + ) + # if no validation set, reassign the splitting functions + if sizes[1] == 0.0: + include_val = False + split_fun = train_test_split + mol_split_fun = train_test_split_molecules + else: + astartes_kwargs["val_size"] = sizes[1] + + n_datapoints = len(mols) + train_replicates, val_replicates, test_replicates = [], [], [] + for _ in range(num_replicates): + train, val, test = None, None, None + match SplitType.get(split): + case SplitType.SCAFFOLD_BALANCED: + mols_without_atommaps = [] + for mol in mols: + copied_mol = copy.deepcopy(mol) + for atom in copied_mol.GetAtoms(): + atom.SetAtomMapNum(0) + mols_without_atommaps.append(copied_mol) + result = mol_split_fun( + np.array(mols_without_atommaps), sampler="scaffold", **astartes_kwargs + ) + train, val, test = _unpack_astartes_result(result, include_val) + + # Use to constrain data with the same smiles go in the same split. 
+ case SplitType.RANDOM_WITH_REPEATED_SMILES: + # get two arrays: one of all the smiles strings, one of just the unique + all_smiles = np.array([Chem.MolToSmiles(mol) for mol in mols]) + unique_smiles = np.unique(all_smiles) + + # save a mapping of smiles -> all the indices that it appeared at + smiles_indices = {} + for smiles in unique_smiles: + smiles_indices[smiles] = np.where(all_smiles == smiles)[0].tolist() + + # randomly split the unique smiles + result = split_fun( + np.arange(len(unique_smiles)), sampler="random", **astartes_kwargs + ) + train_idxs, val_idxs, test_idxs = _unpack_astartes_result(result, include_val) + + # convert these to the 'actual' indices from the original list using the dict we made + train = sum((smiles_indices[unique_smiles[i]] for i in train_idxs), []) + val = sum((smiles_indices[unique_smiles[j]] for j in val_idxs), []) + test = sum((smiles_indices[unique_smiles[k]] for k in test_idxs), []) + + case SplitType.RANDOM: + result = split_fun(np.arange(n_datapoints), sampler="random", **astartes_kwargs) + train, val, test = _unpack_astartes_result(result, include_val) + + case SplitType.KENNARD_STONE: + result = mol_split_fun( + np.array(mols), + sampler="kennard_stone", + hopts=dict(metric="jaccard"), + fingerprint="morgan_fingerprint", + fprints_hopts=dict(n_bits=2048), + **astartes_kwargs, + ) + train, val, test = _unpack_astartes_result(result, include_val) + + case SplitType.KMEANS: + result = mol_split_fun( + np.array(mols), + sampler="kmeans", + hopts=dict(metric="jaccard"), + fingerprint="morgan_fingerprint", + fprints_hopts=dict(n_bits=2048), + **astartes_kwargs, + ) + train, val, test = _unpack_astartes_result(result, include_val) + + case _: + raise RuntimeError("Unreachable code reached!") + train_replicates.append(train) + val_replicates.append(val) + test_replicates.append(test) + astartes_kwargs["random_state"] += 1 + return train_replicates, val_replicates, test_replicates + + +def _unpack_astartes_result( + result: tuple, include_val: bool +) -> tuple[list[int], list[int], list[int]]: + """Helper function to partition input data based on output of astartes sampler + + Parameters + ----------- + result: tuple + Output from call to astartes containing the split indices + include_val: bool + True if a validation set is included, False otherwise. + + Returns + --------- + train: list[int] + val: list[int] + .. 
important:: + validation possibly empty + test: list[int] + """ + train_idxs, val_idxs, test_idxs = [], [], [] + # astartes returns a set of lists containing the data, clusters (if applicable) + # and indices (always last), so we pull out the indices + if include_val: + train_idxs, val_idxs, test_idxs = result[-3], result[-2], result[-1] + else: + train_idxs, test_idxs = result[-2], result[-1] + return list(train_idxs), list(val_idxs), list(test_idxs) + + +def split_data_by_indices( + data: Datapoints | MulticomponentDatapoints, + train_indices: Iterable[Iterable[int]] | None = None, + val_indices: Iterable[Iterable[int]] | None = None, + test_indices: Iterable[Iterable[int]] | None = None, +): + """Splits data into training, validation, and test groups based on split indices given.""" + + train_data = _splitter_helper(data, train_indices) + val_data = _splitter_helper(data, val_indices) + test_data = _splitter_helper(data, test_indices) + + return train_data, val_data, test_data + + +def _splitter_helper(data, indices): + if indices is None: + return None + + if isinstance(data[0], (MoleculeDatapoint, ReactionDatapoint)): + datapoints = data + idxss = indices + return [[datapoints[idx] for idx in idxs] for idxs in idxss] + else: + datapointss = data + idxss = indices + return [[[datapoints[idx] for idx in idxs] for datapoints in datapointss] for idxs in idxss] diff --git a/chemprop-updated/chemprop/exceptions.py b/chemprop-updated/chemprop/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..29229ca41753dbe886312ff91850e75e7d69a556 --- /dev/null +++ b/chemprop-updated/chemprop/exceptions.py @@ -0,0 +1,12 @@ +from typing import Iterable + +from chemprop.utils import pretty_shape + + +class InvalidShapeError(ValueError): + def __init__(self, var_name: str, received: Iterable[int], expected: Iterable[int]): + message = ( + f"arg '{var_name}' has incorrect shape! " + f"got: `{pretty_shape(received)}`. 
expected: `{pretty_shape(expected)}`" + ) + super().__init__(message) diff --git a/chemprop-updated/chemprop/features/__pycache__/__init__.cpython-37.pyc b/chemprop-updated/chemprop/features/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba2b45b0e6e63a3a213f6d022537d2ad7497dc17 Binary files /dev/null and b/chemprop-updated/chemprop/features/__pycache__/__init__.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/features/__pycache__/features_generators.cpython-37.pyc b/chemprop-updated/chemprop/features/__pycache__/features_generators.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8c1bbf65c9224d7bbcf025ea68df7e932d49ece Binary files /dev/null and b/chemprop-updated/chemprop/features/__pycache__/features_generators.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/features/__pycache__/featurization.cpython-37.pyc b/chemprop-updated/chemprop/features/__pycache__/featurization.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cb4f7da9e474aae8a68f5e1ef355d2843a3914b Binary files /dev/null and b/chemprop-updated/chemprop/features/__pycache__/featurization.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/features/__pycache__/utils.cpython-37.pyc b/chemprop-updated/chemprop/features/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5759d70f4ca7a6203e1ac4eeddc98cc02306952e Binary files /dev/null and b/chemprop-updated/chemprop/features/__pycache__/utils.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/featurizers/__init__.py b/chemprop-updated/chemprop/featurizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a266fd820ac640d47a22b5a68a6afcb2ab7a2d9c --- /dev/null +++ b/chemprop-updated/chemprop/featurizers/__init__.py @@ -0,0 +1,52 @@ +from .atom import AtomFeatureMode, MultiHotAtomFeaturizer, get_multi_hot_atom_featurizer +from .base import Featurizer, GraphFeaturizer, S, T, VectorFeaturizer +from .bond import MultiHotBondFeaturizer +from .molecule import ( + BinaryFeaturizerMixin, + CountFeaturizerMixin, + MoleculeFeaturizerRegistry, + MorganBinaryFeaturizer, + MorganCountFeaturizer, + MorganFeaturizerMixin, + RDKit2DFeaturizer, + V1RDKit2DFeaturizer, + V1RDKit2DNormalizedFeaturizer, +) +from .molgraph import ( + CGRFeaturizer, + CondensedGraphOfReactionFeaturizer, + MolGraphCache, + MolGraphCacheFacade, + MolGraphCacheOnTheFly, + RxnMode, + SimpleMoleculeMolGraphFeaturizer, +) + +__all__ = [ + "Featurizer", + "S", + "T", + "VectorFeaturizer", + "GraphFeaturizer", + "MultiHotAtomFeaturizer", + "AtomFeatureMode", + "get_multi_hot_atom_featurizer", + "MultiHotBondFeaturizer", + "MolGraphCacheFacade", + "MolGraphCache", + "MolGraphCacheOnTheFly", + "SimpleMoleculeMolGraphFeaturizer", + "CondensedGraphOfReactionFeaturizer", + "CGRFeaturizer", + "RxnMode", + "MoleculeFeaturizer", + "MorganFeaturizerMixin", + "BinaryFeaturizerMixin", + "CountFeaturizerMixin", + "MorganBinaryFeaturizer", + "MorganCountFeaturizer", + "RDKit2DFeaturizer", + "MoleculeFeaturizerRegistry", + "V1RDKit2DFeaturizer", + "V1RDKit2DNormalizedFeaturizer", +] diff --git a/chemprop-updated/chemprop/featurizers/atom.py b/chemprop-updated/chemprop/featurizers/atom.py new file mode 100644 index 0000000000000000000000000000000000000000..c224423f1a4f311bd371f9a2e83666a138f0659d --- /dev/null +++ b/chemprop-updated/chemprop/featurizers/atom.py @@ -0,0 +1,281 @@ +from enum import 
auto +from typing import Sequence + +import numpy as np +from rdkit.Chem.rdchem import Atom, HybridizationType + +from chemprop.featurizers.base import VectorFeaturizer +from chemprop.utils.utils import EnumMapping + + +class MultiHotAtomFeaturizer(VectorFeaturizer[Atom]): + """A :class:`MultiHotAtomFeaturizer` uses a multi-hot encoding to featurize atoms. + + .. seealso:: + The class provides three default parameterization schemes: + + * :meth:`MultiHotAtomFeaturizer.v1` + * :meth:`MultiHotAtomFeaturizer.v2` + * :meth:`MultiHotAtomFeaturizer.organic` + + The generated atom features are ordered as follows: + * atomic number + * degree + * formal charge + * chiral tag + * number of hydrogens + * hybridization + * aromaticity + * mass + + .. important:: + Each feature, except for aromaticity and mass, includes a pad for unknown values. + + Parameters + ---------- + atomic_nums : Sequence[int] + the choices for atom type denoted by atomic number. Ex: ``[6, 7, 8]`` for C, N, and O. + degrees : Sequence[int] + the choices for number of bonds an atom is engaged in. + formal_charges : Sequence[int] + the choices for integer electronic charge assigned to an atom. + chiral_tags : Sequence[int] + the choices for an atom's chiral tag. See :class:`rdkit.Chem.rdchem.ChiralType` for possible integer values. + num_Hs : Sequence[int] + the choices for number of bonded hydrogen atoms. + hybridizations : Sequence[int] + the choices for an atom’s hybridization type. See :class:`rdkit.Chem.rdchem.HybridizationType` for possible integer values. + """ + + def __init__( + self, + atomic_nums: Sequence[int], + degrees: Sequence[int], + formal_charges: Sequence[int], + chiral_tags: Sequence[int], + num_Hs: Sequence[int], + hybridizations: Sequence[int], + ): + self.atomic_nums = {j: i for i, j in enumerate(atomic_nums)} + self.degrees = {i: i for i in degrees} + self.formal_charges = {j: i for i, j in enumerate(formal_charges)} + self.chiral_tags = {i: i for i in chiral_tags} + self.num_Hs = {i: i for i in num_Hs} + self.hybridizations = {ht: i for i, ht in enumerate(hybridizations)} + + self._subfeats: list[dict] = [ + self.atomic_nums, + self.degrees, + self.formal_charges, + self.chiral_tags, + self.num_Hs, + self.hybridizations, + ] + subfeat_sizes = [ + 1 + len(self.atomic_nums), + 1 + len(self.degrees), + 1 + len(self.formal_charges), + 1 + len(self.chiral_tags), + 1 + len(self.num_Hs), + 1 + len(self.hybridizations), + 1, + 1, + ] + self.__size = sum(subfeat_sizes) + + def __len__(self) -> int: + return self.__size + + def __call__(self, a: Atom | None) -> np.ndarray: + x = np.zeros(self.__size) + + if a is None: + return x + + feats = [ + a.GetAtomicNum(), + a.GetTotalDegree(), + a.GetFormalCharge(), + int(a.GetChiralTag()), + int(a.GetTotalNumHs()), + a.GetHybridization(), + ] + i = 0 + for feat, choices in zip(feats, self._subfeats): + j = choices.get(feat, len(choices)) + x[i + j] = 1 + i += len(choices) + 1 + x[i] = int(a.GetIsAromatic()) + x[i + 1] = 0.01 * a.GetMass() + + return x + + def num_only(self, a: Atom) -> np.ndarray: + """featurize the atom by setting only the atomic number bit""" + x = np.zeros(len(self)) + + if a is None: + return x + + i = self.atomic_nums.get(a.GetAtomicNum(), len(self.atomic_nums)) + x[i] = 1 + + return x + + @classmethod + def v1(cls, max_atomic_num: int = 100): + """The original implementation used in Chemprop V1 [1]_, [2]_.
+ + Parameters + ---------- + max_atomic_num : int, default=100 + Include a bit for all atomic numbers in the interval :math:`[1, \mathtt{max\_atomic\_num}]` + + References + ----------- + .. [1] Yang, K.; Swanson, K.; Jin, W.; Coley, C.; Eiden, P.; Gao, H.; Guzman-Perez, A.; Hopper, T.; + Kelley, B.; Mathea, M.; Palmer, A. "Analyzing Learned Molecular Representations for Property Prediction." + J. Chem. Inf. Model. 2019, 59 (8), 3370–3388. https://doi.org/10.1021/acs.jcim.9b00237 + .. [2] Heid, E.; Greenman, K.P.; Chung, Y.; Li, S.C.; Graff, D.E.; Vermeire, F.H.; Wu, H.; Green, W.H.; McGill, + C.J. "Chemprop: A machine learning package for chemical property prediction." J. Chem. Inf. Model. 2024, + 64 (1), 9–17. https://doi.org/10.1021/acs.jcim.3c01250 + """ + + return cls( + atomic_nums=list(range(1, max_atomic_num + 1)), + degrees=list(range(6)), + formal_charges=[-1, -2, 1, 2, 0], + chiral_tags=list(range(4)), + num_Hs=list(range(5)), + hybridizations=[ + HybridizationType.SP, + HybridizationType.SP2, + HybridizationType.SP3, + HybridizationType.SP3D, + HybridizationType.SP3D2, + ], + ) + + @classmethod + def v2(cls): + """An implementation that includes an atom type bit for all elements in the first four rows of the periodic table plus iodine.""" + + return cls( + atomic_nums=list(range(1, 37)) + [53], + degrees=list(range(6)), + formal_charges=[-1, -2, 1, 2, 0], + chiral_tags=list(range(4)), + num_Hs=list(range(5)), + hybridizations=[ + HybridizationType.S, + HybridizationType.SP, + HybridizationType.SP2, + HybridizationType.SP2D, + HybridizationType.SP3, + HybridizationType.SP3D, + HybridizationType.SP3D2, + ], + ) + + @classmethod + def organic(cls): + r"""A specific parameterization intended for use with organic or drug-like molecules. + + This parameterization features: + 1. includes an atomic number bit only for H, B, C, N, O, F, Si, P, S, Cl, Br, and I atoms + 2. a hybridization bit for :math:`s, sp, sp^2` and :math:`sp^3` hybridizations. + """ + + return cls( + atomic_nums=[1, 5, 6, 7, 8, 9, 14, 15, 16, 17, 35, 53], + degrees=list(range(6)), + formal_charges=[-1, -2, 1, 2, 0], + chiral_tags=list(range(4)), + num_Hs=list(range(5)), + hybridizations=[ + HybridizationType.S, + HybridizationType.SP, + HybridizationType.SP2, + HybridizationType.SP3, + ], + ) + + +class RIGRAtomFeaturizer(VectorFeaturizer[Atom]): + """A :class:`RIGRAtomFeaturizer` uses a multi-hot encoding to featurize atoms using resonance-invariant features. 
+ + The generated atom features are ordered as follows: + * atomic number + * degree + * number of hydrogens + * mass + """ + + def __init__( + self, + atomic_nums: Sequence[int] | None = None, + degrees: Sequence[int] | None = None, + num_Hs: Sequence[int] | None = None, + ): + self.atomic_nums = {j: i for i, j in enumerate(atomic_nums or list(range(1, 37)) + [53])} + self.degrees = {i: i for i in (degrees or list(range(6)))} + self.num_Hs = {i: i for i in (num_Hs or list(range(5)))} + + self._subfeats: list[dict] = [self.atomic_nums, self.degrees, self.num_Hs] + subfeat_sizes = [1 + len(self.atomic_nums), 1 + len(self.degrees), 1 + len(self.num_Hs), 1] + self.__size = sum(subfeat_sizes) + + def __len__(self) -> int: + return self.__size + + def __call__(self, a: Atom | None) -> np.ndarray: + x = np.zeros(self.__size) + + if a is None: + return x + + feats = [a.GetAtomicNum(), a.GetTotalDegree(), int(a.GetTotalNumHs())] + i = 0 + for feat, choices in zip(feats, self._subfeats): + j = choices.get(feat, len(choices)) + x[i + j] = 1 + i += len(choices) + 1 + x[i] = 0.01 * a.GetMass() # scaled to about the same range as other features + + return x + + def num_only(self, a: Atom) -> np.ndarray: + """featurize the atom by setting only the atomic number bit""" + x = np.zeros(len(self)) + + if a is None: + return x + + i = self.atomic_nums.get(a.GetAtomicNum(), len(self.atomic_nums)) + x[i] = 1 + + return x + + +class AtomFeatureMode(EnumMapping): + """The mode of an atom is used for featurization into a `MolGraph`""" + + V1 = auto() + V2 = auto() + ORGANIC = auto() + RIGR = auto() + + +def get_multi_hot_atom_featurizer(mode: str | AtomFeatureMode) -> MultiHotAtomFeaturizer: + """Build the corresponding multi-hot atom featurizer.""" + match AtomFeatureMode.get(mode): + case AtomFeatureMode.V1: + return MultiHotAtomFeaturizer.v1() + case AtomFeatureMode.V2: + return MultiHotAtomFeaturizer.v2() + case AtomFeatureMode.ORGANIC: + return MultiHotAtomFeaturizer.organic() + case AtomFeatureMode.RIGR: + return RIGRAtomFeaturizer() + case _: + raise RuntimeError("unreachable code reached!") diff --git a/chemprop-updated/chemprop/featurizers/base.py b/chemprop-updated/chemprop/featurizers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..29b876bd8751e13ac151c43f3a7d8b1d42d4a831 --- /dev/null +++ b/chemprop-updated/chemprop/featurizers/base.py @@ -0,0 +1,30 @@ +from abc import abstractmethod +from collections.abc import Sized +from typing import Generic, TypeVar + +import numpy as np + +from chemprop.data.molgraph import MolGraph + +S = TypeVar("S") +T = TypeVar("T") + + +class Featurizer(Generic[S, T]): + """An :class:`Featurizer` featurizes inputs type ``S`` into outputs of + type ``T``.""" + + @abstractmethod + def __call__(self, input: S, *args, **kwargs) -> T: + """featurize an input""" + + +class VectorFeaturizer(Featurizer[S, np.ndarray], Sized): + ... + + +class GraphFeaturizer(Featurizer[S, MolGraph]): + @property + @abstractmethod + def shape(self) -> tuple[int, int]: + ... 
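+
+
+# --- Illustrative sketch (an assumption for clarity, not part of the chemprop API) ---
+# A custom vector featurizer only needs to subclass ``VectorFeaturizer`` and implement
+# ``__call__`` and ``__len__``. The hypothetical class below featurizes an RDKit ``Mol``
+# by its heavy-atom count alone and is not used anywhere else in this package:
+#
+# from rdkit.Chem import Mol
+#
+# class HeavyAtomCountFeaturizer(VectorFeaturizer[Mol]):
+#     def __len__(self) -> int:
+#         return 1
+#
+#     def __call__(self, mol: Mol, *args, **kwargs) -> np.ndarray:
+#         return np.array([mol.GetNumHeavyAtoms()], dtype=float)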
diff --git a/chemprop-updated/chemprop/featurizers/bond.py b/chemprop-updated/chemprop/featurizers/bond.py new file mode 100644 index 0000000000000000000000000000000000000000..c604b89d1c7b7d991fac2ebbce9f866cc1b1603c --- /dev/null +++ b/chemprop-updated/chemprop/featurizers/bond.py @@ -0,0 +1,122 @@ +from typing import Sequence + +import numpy as np +from rdkit.Chem.rdchem import Bond, BondType + +from chemprop.featurizers.base import VectorFeaturizer + + +class MultiHotBondFeaturizer(VectorFeaturizer[Bond]): + """A :class:`MultiHotBondFeaturizer` featurizes bonds based on the following attributes: + + * ``null``-ity (i.e., is the bond ``None``?) + * bond type + * conjugated? + * in ring? + * stereochemistry + + The feature vectors produced by this featurizer have the following (general) signature: + + +---------------------+-----------------+--------------+ + | slice [start, stop) | subfeature | unknown pad? | + +=====================+=================+==============+ + | 0-1 | null? | N | + +---------------------+-----------------+--------------+ + | 1-5 | bond type | N | + +---------------------+-----------------+--------------+ + | 5-6 | conjugated? | N | + +---------------------+-----------------+--------------+ + | 6-7 | in ring? | N | + +---------------------+-----------------+--------------+ + | 7-14 | stereochemistry | Y | + +---------------------+-----------------+--------------+ + + **NOTE**: the above signature only applies for the default arguments, as the bond type and + stereochemistry slices can increase in size depending on the input arguments. + + Parameters + ---------- + bond_types : Sequence[BondType] | None, default=[SINGLE, DOUBLE, TRIPLE, AROMATIC] + the known bond types + stereos : Sequence[int] | None, default=[0, 1, 2, 3, 4, 5] + the known bond stereochemistries. See [1]_ for more details + + References + ---------- + .. [1] https://www.rdkit.org/docs/source/rdkit.Chem.rdchem.html#rdkit.Chem.rdchem.BondStereo.values + """ + + def __init__( + self, bond_types: Sequence[BondType] | None = None, stereos: Sequence[int] | None = None + ): + self.bond_types = bond_types or [ + BondType.SINGLE, + BondType.DOUBLE, + BondType.TRIPLE, + BondType.AROMATIC, + ] + self.stereo = stereos or range(6) + + def __len__(self): + return 1 + len(self.bond_types) + 2 + (len(self.stereo) + 1) + + def __call__(self, b: Bond) -> np.ndarray: + x = np.zeros(len(self), int) + + if b is None: + x[0] = 1 + return x + + i = 1 + bond_type = b.GetBondType() + bt_bit, size = self.one_hot_index(bond_type, self.bond_types) + if bt_bit != size: + x[i + bt_bit] = 1 + i += size - 1 + + x[i] = int(b.GetIsConjugated()) + x[i + 1] = int(b.IsInRing()) + i += 2 + + stereo_bit, _ = self.one_hot_index(int(b.GetStereo()), self.stereo) + x[i + stereo_bit] = 1 + + return x + + @classmethod + def one_hot_index(cls, x, xs: Sequence) -> tuple[int, int]: + """Returns a tuple of the index of ``x`` in ``xs`` and ``len(xs) + 1`` if ``x`` is in ``xs``. + Otherwise, returns a tuple with ``len(xs)`` and ``len(xs) + 1``.""" + n = len(xs) + + return xs.index(x) if x in xs else n, n + 1 + + +class RIGRBondFeaturizer(VectorFeaturizer[Bond]): + """A :class:`RIGRBondFeaturizer` featurizes bonds based only on the resonance-invariant features: + + * ``null``-ity (i.e., is the bond ``None``?) + * in ring?
+ """ + + def __len__(self): + return 2 + + def __call__(self, b: Bond) -> np.ndarray: + x = np.zeros(len(self), int) + + if b is None: + x[0] = 1 + return x + + x[1] = int(b.IsInRing()) + + return x + + @classmethod + def one_hot_index(cls, x, xs: Sequence) -> tuple[int, int]: + """Returns a tuple of the index of ``x`` in ``xs`` and ``len(xs) + 1`` if ``x`` is in ``xs``. + Otherwise, returns a tuple with ``len(xs)`` and ``len(xs) + 1``.""" + n = len(xs) + + return xs.index(x) if x in xs else n, n + 1 diff --git a/chemprop-updated/chemprop/featurizers/molecule.py b/chemprop-updated/chemprop/featurizers/molecule.py new file mode 100644 index 0000000000000000000000000000000000000000..df35f066f27b64dc6524c54dbbc0c3e4d7233bdb --- /dev/null +++ b/chemprop-updated/chemprop/featurizers/molecule.py @@ -0,0 +1,104 @@ +import logging + +from descriptastorus.descriptors import rdDescriptors, rdNormalizedDescriptors +import numpy as np +from rdkit import Chem +from rdkit.Chem import Descriptors, Mol +from rdkit.Chem.rdFingerprintGenerator import GetMorganGenerator + +from chemprop.featurizers.base import VectorFeaturizer +from chemprop.utils import ClassRegistry + +logger = logging.getLogger(__name__) + +MoleculeFeaturizerRegistry = ClassRegistry[VectorFeaturizer[Mol]]() + + +class MorganFeaturizerMixin: + def __init__(self, radius: int = 2, length: int = 2048, include_chirality: bool = True): + if radius < 0: + raise ValueError(f"arg 'radius' must be >= 0! got: {radius}") + + self.length = length + self.F = GetMorganGenerator( + radius=radius, fpSize=length, includeChirality=include_chirality + ) + + def __len__(self) -> int: + return self.length + + +class BinaryFeaturizerMixin: + def __call__(self, mol: Chem.Mol) -> np.ndarray: + return self.F.GetFingerprintAsNumPy(mol) + + +class CountFeaturizerMixin: + def __call__(self, mol: Chem.Mol) -> np.ndarray: + return self.F.GetCountFingerprintAsNumPy(mol).astype(np.int32) + + +@MoleculeFeaturizerRegistry("morgan_binary") +class MorganBinaryFeaturizer(MorganFeaturizerMixin, BinaryFeaturizerMixin, VectorFeaturizer[Mol]): + pass + + +@MoleculeFeaturizerRegistry("morgan_count") +class MorganCountFeaturizer(MorganFeaturizerMixin, CountFeaturizerMixin, VectorFeaturizer[Mol]): + pass + + +@MoleculeFeaturizerRegistry("rdkit_2d") +class RDKit2DFeaturizer(VectorFeaturizer[Mol]): + def __init__(self): + logger.warning( + "The RDKit 2D features can deviate signifcantly from a normal distribution. Consider " + "manually scaling them using an appropriate scaler before creating datapoints, rather " + "than using the scikit-learn `StandardScaler` (the default in Chemprop)." 
+ ) + + def __len__(self) -> int: + return len(Descriptors.descList) + + def __call__(self, mol: Chem.Mol) -> np.ndarray: + features = np.array( + [ + 0.0 if name == "SPS" and mol.GetNumHeavyAtoms() == 0 else func(mol) + for name, func in Descriptors.descList + ], + dtype=float, + ) + + return features + + +class V1RDKit2DFeaturizerMixin(VectorFeaturizer[Mol]): + def __len__(self) -> int: + return 200 + + def __call__(self, mol: Mol) -> np.ndarray: + smiles = Chem.MolToSmiles(mol, isomericSmiles=True) + features = self.generator.process(smiles)[1:] + + return np.array(features) + + +@MoleculeFeaturizerRegistry("v1_rdkit_2d") +class V1RDKit2DFeaturizer(V1RDKit2DFeaturizerMixin): + def __init__(self): + self.generator = rdDescriptors.RDKit2D() + + +@MoleculeFeaturizerRegistry("v1_rdkit_2d_normalized") +class V1RDKit2DNormalizedFeaturizer(V1RDKit2DFeaturizerMixin): + def __init__(self): + self.generator = rdNormalizedDescriptors.RDKit2DNormalized() + + +@MoleculeFeaturizerRegistry("charge") +class ChargeFeaturizer(VectorFeaturizer[Mol]): + def __call__(self, mol: Chem.Mol) -> np.ndarray: + return np.array([Chem.GetFormalCharge(mol)]) + + def __len__(self) -> int: + return 1 diff --git a/chemprop-updated/chemprop/featurizers/molgraph/__init__.py b/chemprop-updated/chemprop/featurizers/molgraph/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb21580de633d627d3144c55fe809d33466d26e --- /dev/null +++ b/chemprop-updated/chemprop/featurizers/molgraph/__init__.py @@ -0,0 +1,13 @@ +from .cache import MolGraphCache, MolGraphCacheFacade, MolGraphCacheOnTheFly +from .molecule import SimpleMoleculeMolGraphFeaturizer +from .reaction import CGRFeaturizer, CondensedGraphOfReactionFeaturizer, RxnMode + +__all__ = [ + "MolGraphCacheFacade", + "MolGraphCache", + "MolGraphCacheOnTheFly", + "SimpleMoleculeMolGraphFeaturizer", + "CondensedGraphOfReactionFeaturizer", + "CGRFeaturizer", + "RxnMode", +] diff --git a/chemprop-updated/chemprop/featurizers/molgraph/cache.py b/chemprop-updated/chemprop/featurizers/molgraph/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..171d2b26f21c19d42539843d29c765b773651e2c --- /dev/null +++ b/chemprop-updated/chemprop/featurizers/molgraph/cache.py @@ -0,0 +1,89 @@ +from abc import abstractmethod +from collections.abc import Sequence +from typing import Generic, Iterable + +import numpy as np + +from chemprop.data.molgraph import MolGraph +from chemprop.featurizers.base import Featurizer, S + + +class MolGraphCacheFacade(Sequence[MolGraph], Generic[S]): + """ + A :class:`MolGraphCacheFacade` provided an interface for caching + :class:`~chemprop.data.molgraph.MolGraph`\s. + + .. note:: + This class only provides a facade for a cached dataset, but it *does not guarantee* + whether the underlying data is truly cached. + + + Parameters + ---------- + inputs : Iterable[S] + The inputs to be featurized. + V_fs : Iterable[np.ndarray] + The node features for each input. + E_fs : Iterable[np.ndarray] + The edge features for each input. + featurizer : Featurizer[S, MolGraph] + The featurizer with which to generate the + :class:`~chemprop.data.molgraph.MolGraph`\s. + """ + + @abstractmethod + def __init__( + self, + inputs: Iterable[S], + V_fs: Iterable[np.ndarray], + E_fs: Iterable[np.ndarray], + featurizer: Featurizer[S, MolGraph], + ): + pass + + +class MolGraphCache(MolGraphCacheFacade): + """ + A :class:`MolGraphCache` precomputes the corresponding + :class:`~chemprop.data.molgraph.MolGraph`\s and caches them in memory. 
+ """ + + def __init__( + self, + inputs: Iterable[S], + V_fs: Iterable[np.ndarray | None], + E_fs: Iterable[np.ndarray | None], + featurizer: Featurizer[S, MolGraph], + ): + self._mgs = [featurizer(input, V_f, E_f) for input, V_f, E_f in zip(inputs, V_fs, E_fs)] + + def __len__(self) -> int: + return len(self._mgs) + + def __getitem__(self, index: int) -> MolGraph: + return self._mgs[index] + + +class MolGraphCacheOnTheFly(MolGraphCacheFacade): + """ + A :class:`MolGraphCacheOnTheFly` computes the corresponding + :class:`~chemprop.data.molgraph.MolGraph`\s as they are requested. + """ + + def __init__( + self, + inputs: Iterable[S], + V_fs: Iterable[np.ndarray | None], + E_fs: Iterable[np.ndarray | None], + featurizer: Featurizer[S, MolGraph], + ): + self._inputs = list(inputs) + self._V_fs = list(V_fs) + self._E_fs = list(E_fs) + self._featurizer = featurizer + + def __len__(self) -> int: + return len(self._inputs) + + def __getitem__(self, index: int) -> MolGraph: + return self._featurizer(self._inputs[index], self._V_fs[index], self._E_fs[index]) diff --git a/chemprop-updated/chemprop/featurizers/molgraph/mixins.py b/chemprop-updated/chemprop/featurizers/molgraph/mixins.py new file mode 100644 index 0000000000000000000000000000000000000000..afa461d481388d51f6e8434a21a5f5f99199616a --- /dev/null +++ b/chemprop-updated/chemprop/featurizers/molgraph/mixins.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass, field + +from rdkit.Chem.rdchem import Atom, Bond + +from chemprop.featurizers.atom import MultiHotAtomFeaturizer +from chemprop.featurizers.base import VectorFeaturizer +from chemprop.featurizers.bond import MultiHotBondFeaturizer + + +@dataclass +class _MolGraphFeaturizerMixin: + atom_featurizer: VectorFeaturizer[Atom] = field(default_factory=MultiHotAtomFeaturizer.v2) + bond_featurizer: VectorFeaturizer[Bond] = field(default_factory=MultiHotBondFeaturizer) + + def __post_init__(self): + self.atom_fdim = len(self.atom_featurizer) + self.bond_fdim = len(self.bond_featurizer) + + @property + def shape(self) -> tuple[int, int]: + """the feature dimension of the atoms and bonds, respectively, of `MolGraph`s generated by + this featurizer""" + return self.atom_fdim, self.bond_fdim diff --git a/chemprop-updated/chemprop/featurizers/molgraph/molecule.py b/chemprop-updated/chemprop/featurizers/molgraph/molecule.py new file mode 100644 index 0000000000000000000000000000000000000000..7ac7fafd4e15c57e1823ff0904e0888126c8352c --- /dev/null +++ b/chemprop-updated/chemprop/featurizers/molgraph/molecule.py @@ -0,0 +1,91 @@ +from dataclasses import InitVar, dataclass + +import numpy as np +from rdkit import Chem +from rdkit.Chem import Mol + +from chemprop.data.molgraph import MolGraph +from chemprop.featurizers.base import GraphFeaturizer +from chemprop.featurizers.molgraph.mixins import _MolGraphFeaturizerMixin + + +@dataclass +class SimpleMoleculeMolGraphFeaturizer(_MolGraphFeaturizerMixin, GraphFeaturizer[Mol]): + """A :class:`SimpleMoleculeMolGraphFeaturizer` is the default implementation of a + :class:`MoleculeMolGraphFeaturizer` + + Parameters + ---------- + atom_featurizer : AtomFeaturizer, default=MultiHotAtomFeaturizer() + the featurizer with which to calculate feature representations of the atoms in a given + molecule + bond_featurizer : BondFeaturizer, default=MultiHotBondFeaturizer() + the featurizer with which to calculate feature representations of the bonds in a given + molecule + extra_atom_fdim : int, default=0 + the dimension of the additional features that will be 
concatenated onto the calculated + features of each atom + extra_bond_fdim : int, default=0 + the dimension of the additional features that will be concatenated onto the calculated + features of each bond + """ + + extra_atom_fdim: InitVar[int] = 0 + extra_bond_fdim: InitVar[int] = 0 + + def __post_init__(self, extra_atom_fdim: int = 0, extra_bond_fdim: int = 0): + super().__post_init__() + + self.extra_atom_fdim = extra_atom_fdim + self.extra_bond_fdim = extra_bond_fdim + self.atom_fdim += self.extra_atom_fdim + self.bond_fdim += self.extra_bond_fdim + + def __call__( + self, + mol: Chem.Mol, + atom_features_extra: np.ndarray | None = None, + bond_features_extra: np.ndarray | None = None, + ) -> MolGraph: + n_atoms = mol.GetNumAtoms() + n_bonds = mol.GetNumBonds() + + if atom_features_extra is not None and len(atom_features_extra) != n_atoms: + raise ValueError( + "Input molecule must have same number of atoms as `len(atom_features_extra)`!" + f"got: {n_atoms} and {len(atom_features_extra)}, respectively" + ) + if bond_features_extra is not None and len(bond_features_extra) != n_bonds: + raise ValueError( + "Input molecule must have same number of bonds as `len(bond_features_extra)`!" + f"got: {n_bonds} and {len(bond_features_extra)}, respectively" + ) + + if n_atoms == 0: + V = np.zeros((1, self.atom_fdim), dtype=np.single) + else: + V = np.array([self.atom_featurizer(a) for a in mol.GetAtoms()], dtype=np.single) + E = np.empty((2 * n_bonds, self.bond_fdim)) + edge_index = [[], []] + + if atom_features_extra is not None: + V = np.hstack((V, atom_features_extra)) + + i = 0 + for bond in mol.GetBonds(): + x_e = self.bond_featurizer(bond) + if bond_features_extra is not None: + x_e = np.concatenate((x_e, bond_features_extra[bond.GetIdx()]), dtype=np.single) + + E[i : i + 2] = x_e + + u, v = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + edge_index[0].extend([u, v]) + edge_index[1].extend([v, u]) + + i += 2 + + rev_edge_index = np.arange(len(E)).reshape(-1, 2)[:, ::-1].ravel() + edge_index = np.array(edge_index, int) + + return MolGraph(V, E, edge_index, rev_edge_index) diff --git a/chemprop-updated/chemprop/featurizers/molgraph/reaction.py b/chemprop-updated/chemprop/featurizers/molgraph/reaction.py new file mode 100644 index 0000000000000000000000000000000000000000..f35b03e037b45553743c0af53363a5f9d68585e9 --- /dev/null +++ b/chemprop-updated/chemprop/featurizers/molgraph/reaction.py @@ -0,0 +1,332 @@ +from dataclasses import InitVar, dataclass +from enum import auto +import logging +from typing import Iterable, Sequence, TypeAlias + +import numpy as np +from rdkit import Chem +from rdkit.Chem.rdchem import Bond, Mol + +from chemprop.data.molgraph import MolGraph +from chemprop.featurizers.base import GraphFeaturizer +from chemprop.featurizers.molgraph.mixins import _MolGraphFeaturizerMixin +from chemprop.types import Rxn +from chemprop.utils.utils import EnumMapping + +logger = logging.getLogger(__name__) + + +class RxnMode(EnumMapping): + """The mode by which a reaction should be featurized into a `MolGraph`""" + + REAC_PROD = auto() + """concatenate the reactant features with the product features.""" + REAC_PROD_BALANCE = auto() + """concatenate the reactant features with the products feature and balances imbalanced + reactions""" + REAC_DIFF = auto() + """concatenates the reactant features with the difference in features between reactants and + products""" + REAC_DIFF_BALANCE = auto() + """concatenates the reactant features with the difference in features between reactants and + 
product and balances imbalanced reactions""" + PROD_DIFF = auto() + """concatenates the product features with the difference in features between reactants and + products""" + PROD_DIFF_BALANCE = auto() + """concatenates the product features with the difference in features between reactants and + products and balances imbalanced reactions""" + + +@dataclass +class CondensedGraphOfReactionFeaturizer(_MolGraphFeaturizerMixin, GraphFeaturizer[Rxn]): + """A :class:`CondensedGraphOfReactionFeaturizer` featurizes reactions using the condensed + reaction graph method utilized in [1]_ + + **NOTE**: This class *does not* accept a :class:`AtomFeaturizer` instance. This is because + it requries the :meth:`num_only()` method, which is only implemented in the concrete + :class:`AtomFeaturizer` class + + Parameters + ---------- + atom_featurizer : AtomFeaturizer, default=AtomFeaturizer() + the featurizer with which to calculate feature representations of the atoms in a given + molecule + bond_featurizer : BondFeaturizerBase, default=BondFeaturizer() + the featurizer with which to calculate feature representations of the bonds in a given + molecule + mode_ : Union[str, ReactionMode], default=ReactionMode.REAC_DIFF + the mode by which to featurize the reaction as either the string code or enum value + + References + ---------- + .. [1] Heid, E.; Green, W.H. "Machine Learning of Reaction Properties via Learned + Representations of the Condensed Graph of Reaction." J. Chem. Inf. Model. 2022, 62, + 2101-2110. https://doi.org/10.1021/acs.jcim.1c00975 + """ + + mode_: InitVar[str | RxnMode] = RxnMode.REAC_DIFF + + def __post_init__(self, mode_: str | RxnMode): + super().__post_init__() + + self.mode = mode_ + self.atom_fdim += len(self.atom_featurizer) - len(self.atom_featurizer.atomic_nums) - 1 + self.bond_fdim *= 2 + + @property + def mode(self) -> RxnMode: + return self.__mode + + @mode.setter + def mode(self, m: str | RxnMode): + self.__mode = RxnMode.get(m) + + def __call__( + self, + rxn: tuple[Chem.Mol, Chem.Mol], + atom_features_extra: np.ndarray | None = None, + bond_features_extra: np.ndarray | None = None, + ) -> MolGraph: + """Featurize the input reaction into a molecular graph + + Parameters + ---------- + rxn : Rxn + a 2-tuple of atom-mapped rdkit molecules, where the 0th element is the reactant and the + 1st element is the product + atom_features_extra : np.ndarray | None, default=None + *UNSUPPORTED* maintained only to maintain parity with the method signature of the + `MoleculeFeaturizer` + bond_features_extra : np.ndarray | None, default=None + *UNSUPPORTED* maintained only to maintain parity with the method signature of the + `MoleculeFeaturizer` + + Returns + ------- + MolGraph + the molecular graph of the reaction + """ + + if atom_features_extra is not None: + logger.warning("'atom_features_extra' is currently unsupported for reactions") + if bond_features_extra is not None: + logger.warning("'bond_features_extra' is currently unsupported for reactions") + + reac, pdt = rxn + r2p_idx_map, pdt_idxs, reac_idxs = self.map_reac_to_prod(reac, pdt) + + V = self._calc_node_feature_matrix(reac, pdt, r2p_idx_map, pdt_idxs, reac_idxs) + E = [] + edge_index = [[], []] + + n_atoms_tot = len(V) + n_atoms_reac = reac.GetNumAtoms() + + for u in range(n_atoms_tot): + for v in range(u + 1, n_atoms_tot): + b_reac, b_prod = self._get_bonds( + reac, pdt, r2p_idx_map, pdt_idxs, n_atoms_reac, u, v + ) + if b_reac is None and b_prod is None: + continue + + x_e = self._calc_edge_feature(b_reac, b_prod) + 
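+                # the same CGR edge feature is appended once per direction (u->v and v->u), matching the two edge_index entries below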
E.extend([x_e, x_e]) + edge_index[0].extend([u, v]) + edge_index[1].extend([v, u]) + + E = np.array(E) if len(E) > 0 else np.empty((0, self.bond_fdim)) + rev_edge_index = np.arange(len(E)).reshape(-1, 2)[:, ::-1].ravel() + edge_index = np.array(edge_index, int) + + return MolGraph(V, E, edge_index, rev_edge_index) + + def _calc_node_feature_matrix( + self, + rct: Mol, + pdt: Mol, + r2p_idx_map: dict[int, int], + pdt_idxs: Iterable[int], + reac_idxs: Iterable[int], + ) -> np.ndarray: + """Calculate the node feature matrix for the reaction""" + X_v_r1 = np.array([self.atom_featurizer(a) for a in rct.GetAtoms()]) + X_v_p2 = np.array([self.atom_featurizer(pdt.GetAtomWithIdx(i)) for i in pdt_idxs]) + X_v_p2 = X_v_p2.reshape(-1, X_v_r1.shape[1]) + + if self.mode in [RxnMode.REAC_DIFF, RxnMode.PROD_DIFF, RxnMode.REAC_PROD]: + # Reactant: + # (1) regular features for each atom in the reactants + # (2) zero features for each atom that's only in the products + X_v_r2 = [self.atom_featurizer.num_only(pdt.GetAtomWithIdx(i)) for i in pdt_idxs] + X_v_r2 = np.array(X_v_r2).reshape(-1, X_v_r1.shape[1]) + + # Product: + # (1) either (a) product-side features for each atom in both + # or (b) zero features for each atom only in the reatants + # (2) regular features for each atom only in the products + X_v_p1 = np.array( + [ + ( + self.atom_featurizer(pdt.GetAtomWithIdx(r2p_idx_map[a.GetIdx()])) + if a.GetIdx() not in reac_idxs + else self.atom_featurizer.num_only(a) + ) + for a in rct.GetAtoms() + ] + ) + else: + # Reactant: + # (1) regular features for each atom in the reactants + # (2) regular features for each atom only in the products + X_v_r2 = [self.atom_featurizer(pdt.GetAtomWithIdx(i)) for i in pdt_idxs] + X_v_r2 = np.array(X_v_r2).reshape(-1, X_v_r1.shape[1]) + + # Product: + # (1) either (a) product-side features for each atom in both + # or (b) reactant-side features for each atom only in the reatants + # (2) regular features for each atom only in the products + X_v_p1 = np.array( + [ + ( + self.atom_featurizer(pdt.GetAtomWithIdx(r2p_idx_map[a.GetIdx()])) + if a.GetIdx() not in reac_idxs + else self.atom_featurizer(a) + ) + for a in rct.GetAtoms() + ] + ) + + X_v_r = np.concatenate((X_v_r1, X_v_r2)) + X_v_p = np.concatenate((X_v_p1, X_v_p2)) + + m = min(len(X_v_r), len(X_v_p)) + + if self.mode in [RxnMode.REAC_PROD, RxnMode.REAC_PROD_BALANCE]: + X_v = np.hstack((X_v_r[:m], X_v_p[:m, len(self.atom_featurizer.atomic_nums) + 1 :])) + else: + X_v_d = X_v_p[:m] - X_v_r[:m] + if self.mode in [RxnMode.REAC_DIFF, RxnMode.REAC_DIFF_BALANCE]: + X_v = np.hstack((X_v_r[:m], X_v_d[:m, len(self.atom_featurizer.atomic_nums) + 1 :])) + else: + X_v = np.hstack((X_v_p[:m], X_v_d[:m, len(self.atom_featurizer.atomic_nums) + 1 :])) + + return X_v + + def _get_bonds( + self, + rct: Bond, + pdt: Bond, + ri2pj: dict[int, int], + pids: Sequence[int], + n_atoms_r: int, + u: int, + v: int, + ) -> tuple[Bond, Bond]: + """get the corresponding reactant- and product-side bond, respectively, betweeen atoms `u` and `v`""" + if u >= n_atoms_r and v >= n_atoms_r: + b_prod = pdt.GetBondBetweenAtoms(pids[u - n_atoms_r], pids[v - n_atoms_r]) + + if self.mode in [ + RxnMode.REAC_PROD_BALANCE, + RxnMode.REAC_DIFF_BALANCE, + RxnMode.PROD_DIFF_BALANCE, + ]: + b_reac = b_prod + else: + b_reac = None + elif u < n_atoms_r and v >= n_atoms_r: # One atom only in product + b_reac = None + + if u in ri2pj: + b_prod = pdt.GetBondBetweenAtoms(ri2pj[u], pids[v - n_atoms_r]) + else: # Atom atom only in reactant, the other only in product + b_prod = 
None + else: + b_reac = rct.GetBondBetweenAtoms(u, v) + + if u in ri2pj and v in ri2pj: # Both atoms in both reactant and product + b_prod = pdt.GetBondBetweenAtoms(ri2pj[u], ri2pj[v]) + elif self.mode in [ + RxnMode.REAC_PROD_BALANCE, + RxnMode.REAC_DIFF_BALANCE, + RxnMode.PROD_DIFF_BALANCE, + ]: + b_prod = None if (u in ri2pj or v in ri2pj) else b_reac + else: # One or both atoms only in reactant + b_prod = None + + return b_reac, b_prod + + def _calc_edge_feature(self, b_reac: Bond, b_pdt: Bond): + """Calculate the global features of the two bonds""" + x_e_r = self.bond_featurizer(b_reac) + x_e_p = self.bond_featurizer(b_pdt) + x_e_d = x_e_p - x_e_r + + if self.mode in [RxnMode.REAC_PROD, RxnMode.REAC_PROD_BALANCE]: + x_e = np.hstack((x_e_r, x_e_p)) + elif self.mode in [RxnMode.REAC_DIFF, RxnMode.REAC_DIFF_BALANCE]: + x_e = np.hstack((x_e_r, x_e_d)) + else: + x_e = np.hstack((x_e_p, x_e_d)) + + return x_e + + @classmethod + def map_reac_to_prod( + cls, reacs: Chem.Mol, pdts: Chem.Mol + ) -> tuple[dict[int, int], list[int], list[int]]: + """Map atom indices between corresponding atoms in the reactant and product molecules + + Parameters + ---------- + reacs : Chem.Mol + An RDKit molecule of the reactants + pdts : Chem.Mol + An RDKit molecule of the products + + Returns + ------- + ri2pi : dict[int, int] + A dictionary of corresponding atom indices from reactant atoms to product atoms + pdt_idxs : list[int] + atom indices of poduct atoms + rct_idxs : list[int] + atom indices of reactant atoms + """ + pdt_idxs = [] + mapno2pj = {} + reac_atommap_nums = {a.GetAtomMapNum() for a in reacs.GetAtoms()} + + for a in pdts.GetAtoms(): + map_num = a.GetAtomMapNum() + j = a.GetIdx() + + if map_num > 0: + mapno2pj[map_num] = j + if map_num not in reac_atommap_nums: + pdt_idxs.append(j) + else: + pdt_idxs.append(j) + + rct_idxs = [] + r2p_idx_map = {} + + for a in reacs.GetAtoms(): + map_num = a.GetAtomMapNum() + i = a.GetIdx() + + if map_num > 0: + try: + r2p_idx_map[i] = mapno2pj[map_num] + except KeyError: + rct_idxs.append(i) + else: + rct_idxs.append(i) + + return r2p_idx_map, pdt_idxs, rct_idxs + + +CGRFeaturizer: TypeAlias = CondensedGraphOfReactionFeaturizer diff --git a/chemprop-updated/chemprop/models/__init__.py b/chemprop-updated/chemprop/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..76946d73b599668b8cdd7adc2a1c48b38b8d1108 --- /dev/null +++ b/chemprop-updated/chemprop/models/__init__.py @@ -0,0 +1,5 @@ +from .model import MPNN +from .multi import MulticomponentMPNN +from .utils import load_model, save_model + +__all__ = ["MPNN", "MulticomponentMPNN", "load_model", "save_model"] diff --git a/chemprop-updated/chemprop/models/__pycache__/__init__.cpython-37.pyc b/chemprop-updated/chemprop/models/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6503ed123f3d65114e035321f0cbf9882c566608 Binary files /dev/null and b/chemprop-updated/chemprop/models/__pycache__/__init__.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/models/__pycache__/ffn.cpython-37.pyc b/chemprop-updated/chemprop/models/__pycache__/ffn.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2b4f0b097a79f0ccabf82aae5f2e77c672a9eb5 Binary files /dev/null and b/chemprop-updated/chemprop/models/__pycache__/ffn.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/models/__pycache__/model.cpython-37.pyc b/chemprop-updated/chemprop/models/__pycache__/model.cpython-37.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..ab53bf808264e8ab0121ba648c645b3852d114cd Binary files /dev/null and b/chemprop-updated/chemprop/models/__pycache__/model.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/models/__pycache__/mpn.cpython-37.pyc b/chemprop-updated/chemprop/models/__pycache__/mpn.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e647541f4fb9aa7804d2c7393f5f55c7bff876a Binary files /dev/null and b/chemprop-updated/chemprop/models/__pycache__/mpn.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/models/model.py b/chemprop-updated/chemprop/models/model.py new file mode 100644 index 0000000000000000000000000000000000000000..c953638d55329f94402615f7ab359a240a755b0c --- /dev/null +++ b/chemprop-updated/chemprop/models/model.py @@ -0,0 +1,374 @@ +from __future__ import annotations + +import io +import logging +from typing import Iterable, TypeAlias + +from lightning import pytorch as pl +import torch +from torch import Tensor, nn, optim + +from chemprop.data import BatchMolGraph, MulticomponentTrainingBatch, TrainingBatch +from chemprop.nn import Aggregation, ChempropMetric, MessagePassing, Predictor +from chemprop.nn.transforms import ScaleTransform +from chemprop.schedulers import build_NoamLike_LRSched +from chemprop.utils.registry import Factory + +logger = logging.getLogger(__name__) + +BatchType: TypeAlias = TrainingBatch | MulticomponentTrainingBatch + +import pandas as pd +from transformers import RobertaTokenizer, RobertaModel +from torch.utils.data import DataLoader +import torch.nn.functional as F + +class ChemBERTaEncoder(nn.Module): + def __init__(self, model_name="DeepChem/ChemBERTa-77M-MLM", fine_tune_percent=10, unfreeze_pooler=True): + super().__init__() + self.tokenizer = RobertaTokenizer.from_pretrained(model_name) + self.encoder = RobertaModel.from_pretrained(model_name) + + # Step 1: Freeze all parameters + for param in self.encoder.parameters(): + param.requires_grad = False + + # Step 2: Unfreeze the top k layers based on fine_tune_percent + num_layers_total = len(self.encoder.encoder.layer) + k = max(1, int(num_layers_total * fine_tune_percent / 100)) + + for layer in self.encoder.encoder.layer[-k:]: # Unfreeze top k layers + for param in layer.parameters(): + param.requires_grad = True + + # Logging + total = sum(p.numel() for p in self.encoder.parameters()) + trainable = sum(p.numel() for p in self.encoder.parameters() if p.requires_grad) + print(f"ChemBERTa Total parameters: {total}") + print(f"Trainable parameters: {trainable} ({100 * trainable / total:.2f}%)") + + def encode(self, smiles_list: list[str], batch_size=64, max_length=128): + device = next(self.encoder.parameters()).device + all_hidden_states = [] + all_pooler_outputs = [] + + for i in range(0, len(smiles_list), batch_size): + batch = smiles_list[i:i+batch_size] + inputs = self.tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=max_length) + inputs = {k: v.to(device) for k, v in inputs.items()} + + with torch.set_grad_enabled(self.encoder.training): + outputs = self.encoder(**inputs) + last_hidden = outputs.last_hidden_state.detach().clone() # [B, L, d_model] + pooler = outputs.pooler_output.detach().clone() # [B, d_model] + all_hidden_states.append(last_hidden) + all_pooler_outputs.append(pooler) + + # Return both as tensors + return { + "last_hidden_state": torch.cat(all_hidden_states, dim=0), + "pooler_output": torch.cat(all_pooler_outputs, dim=0) + } + +class fusionGAT(nn.Module): + 
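+    # GAT-style fusion block used by the modified MPNN below: each ChemBERTa token embedding is
+    # scored against the D-MPNN graph embedding, and the attention-weighted sum of the projected
+    # token embeddings is returned as the fused representation.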
def __init__(self, dmpnn_dim: int, bert_dim: int, hidden_dim: int): + super().__init__() + # Project descriptor and nodes to hidden_dim + self.W_dmpnn = nn.Linear(dmpnn_dim, hidden_dim) + self.W_bert = nn.Linear(bert_dim, hidden_dim) + self.attn_fc = nn.Linear(2 * hidden_dim, 1) + self.leaky_relu = nn.LeakyReLU(0.2) + + def forward(self, dmpnn_output: Tensor, encodings: Tensor) -> Tensor: + """ + desc: (B, dmpnn_dim) + nodes: (B, L, bert_dim) + + Returns: + updated_desc: (B, hidden_dim) + """ + B, L, _ = encodings.size() + + dmpnn_proj = self.W_dmpnn(dmpnn_output) # (B, hidden_dim) + bert_proj = self.W_bert(encodings) # (B, L, hidden_dim) + + # Expand dmpnn_proj to (B, L, hidden_dim) to concatenate with each node + dmpnn_expanded = dmpnn_proj.unsqueeze(1).expand(-1, L, -1) # (B, L, hidden_dim) + + # Concatenate dmpnn and bert features + cat = torch.cat([dmpnn_expanded, bert_proj], dim=-1) # (B, L, 2*hidden_dim) + + # Compute attention scores + e = self.leaky_relu(self.attn_fc(cat)).squeeze(-1) # (B, L) + + # Attention weights over L nodes + alpha = torch.softmax(e, dim=1).unsqueeze(-1) # (B, L, 1) + + # Weighted sum of node features + fusion = torch.sum(alpha * bert_proj, dim=1) # (B, hidden_dim) + + return fusion + + + +class MPNN(pl.LightningModule): + def __init__( + self, + message_passing: MessagePassing, + agg: Aggregation, + predictor: Predictor, + batch_norm: bool = False, + metrics: Iterable[ChempropMetric] | None = None, + warmup_epochs: int = 2, + init_lr: float = 1e-4, + max_lr: float = 1e-3, + final_lr: float = 1e-4, + X_d_transform: ScaleTransform | None = None, + fine_tune_bert: bool = True, + fine_tune_percent: int = 10 + ): + super().__init__() + self.save_hyperparameters(ignore=["X_d_transform", "message_passing", "agg", "predictor"]) + self.hparams["X_d_transform"] = X_d_transform + self.hparams.update({ + "message_passing": message_passing.hparams, + "agg": agg.hparams, + "predictor": predictor.hparams, + }) + + self.fusion_GAT = fusionGAT( + dmpnn_dim=message_passing.output_dim, + bert_dim=768, + hidden_dim=message_passing.output_dim + ) + + self.message_passing = message_passing + self.agg = agg + self.bn = nn.BatchNorm1d(self.message_passing.output_dim) if batch_norm else nn.Identity() + self.predictor = predictor + self.X_d_transform = X_d_transform if X_d_transform is not None else nn.Identity() + + self.metrics = ( + nn.ModuleList([*metrics, self.criterion.clone()]) + if metrics + else nn.ModuleList([self.predictor._T_default_metric(), self.criterion.clone()]) + ) + + self.warmup_epochs = warmup_epochs + self.init_lr = init_lr + self.max_lr = max_lr + self.final_lr = final_lr + + self.fine_tune_bert = fine_tune_bert + self.fine_tune_percent = fine_tune_percent + + + self.bert_encoder = ChemBERTaEncoder( + model_name="seyonec/ChemBERTa-zinc-base-v1", + fine_tune_percent=self.fine_tune_percent if self.fine_tune_bert else 0 + ) + + self.bert_encoder = self.bert_encoder.to(self.device) + + + @property + def output_dim(self) -> int: + return self.predictor.output_dim + + @property + def n_tasks(self) -> int: + return self.predictor.n_tasks + + @property + def n_targets(self) -> int: + return self.predictor.n_targets + + @property + def criterion(self) -> ChempropMetric: + return self.predictor.criterion + + def fingerprint(self, bmg: BatchMolGraph, V_d: Tensor | None = None, X_d: Tensor | None = None) -> Tensor: + H_v = self.message_passing(bmg, V_d) + H = self.agg(H_v, bmg.batch) + + smiles_list = bmg.names + outputs = self.bert_encoder.encode(smiles_list) + 
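+        # note: despite its name, `output_pooler` holds the per-token last_hidden_state (B, L, d_bert);
+        # fusion_GAT attends over these token states rather than the pooled output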
output_pooler = outputs["last_hidden_state"] + + fingerprint = self.fusion_GAT(H, output_pooler) + fingerprint = self.bn(fingerprint) + + return fingerprint if X_d is None else torch.cat((fingerprint, self.X_d_transform(X_d)), 1) + + def encoding(self, bmg: BatchMolGraph, V_d: Tensor | None = None, X_d: Tensor | None = None, i: int = -1) -> Tensor: + return self.predictor.encode(self.fingerprint(bmg, V_d, X_d), i) + + def forward(self, bmg: BatchMolGraph, V_d: Tensor | None = None, X_d: Tensor | None = None) -> Tensor: + return self.predictor(self.fingerprint(bmg, V_d, X_d)) + + def training_step(self, batch: BatchType, batch_idx): + batch_size = self.get_batch_size(batch) + bmg, V_d, X_d, targets, weights, lt_mask, gt_mask = batch + + mask = targets.isfinite() + targets = targets.nan_to_num(nan=0.0) + + Z = self.fingerprint(bmg, V_d, X_d) + preds = self.predictor.train_step(Z) + l = self.criterion(preds, targets, mask, weights, lt_mask, gt_mask) + + self.log("train_loss", self.criterion, batch_size=batch_size, prog_bar=True, on_epoch=True) + return l + + def on_validation_model_eval(self) -> None: + self.eval() + self.message_passing.V_d_transform.train() + self.message_passing.graph_transform.train() + self.X_d_transform.train() + self.predictor.output_transform.train() + + if self.fine_tune_bert: + self.bert_encoder.encoder.train() + + def validation_step(self, batch: BatchType, batch_idx: int = 0): + self._evaluate_batch(batch, "val") + + batch_size = self.get_batch_size(batch) + bmg, V_d, X_d, targets, weights, lt_mask, gt_mask = batch + + mask = targets.isfinite() + targets = targets.nan_to_num(nan=0.0) + + Z = self.fingerprint(bmg, V_d, X_d) + preds = self.predictor.train_step(Z) + self.metrics[-1](preds, targets, mask, weights, lt_mask, gt_mask) + self.log("val_loss", self.metrics[-1], batch_size=batch_size, prog_bar=True) + + def test_step(self, batch: BatchType, batch_idx: int = 0): + self._evaluate_batch(batch, "test") + + def _evaluate_batch(self, batch: BatchType, label: str) -> None: + batch_size = self.get_batch_size(batch) + bmg, V_d, X_d, targets, weights, lt_mask, gt_mask = batch + + mask = targets.isfinite() + targets = targets.nan_to_num(nan=0.0) + preds = self(bmg, V_d, X_d) + weights = torch.ones_like(weights) + + if self.predictor.n_targets > 1: + preds = preds[..., 0] + + for m in self.metrics[:-1]: + m.update(preds, targets, mask, weights, lt_mask, gt_mask) + self.log(f"{label}/{m.alias}", m, batch_size=batch_size) + + def predict_step(self, batch: BatchType, batch_idx: int, dataloader_idx: int = 0) -> Tensor: + bmg, X_vd, X_d, *_ = batch + return self(bmg, X_vd, X_d) + + def configure_optimizers(self): + opt = optim.Adam(self.parameters(), self.init_lr) + if self.trainer.train_dataloader is None: + self.trainer.estimated_stepping_batches + steps_per_epoch = self.trainer.num_training_batches + warmup_steps = self.warmup_epochs * steps_per_epoch + if self.trainer.max_epochs == -1: + logger.warning( + "For infinite training, the number of cooldown epochs in learning rate scheduler is set to 100 times the number of warmup epochs." 
+ ) + cooldown_steps = 100 * warmup_steps + else: + cooldown_epochs = self.trainer.max_epochs - self.warmup_epochs + cooldown_steps = cooldown_epochs * steps_per_epoch + + lr_sched = build_NoamLike_LRSched( + opt, warmup_steps, cooldown_steps, self.init_lr, self.max_lr, self.final_lr + ) + + return {"optimizer": opt, "lr_scheduler": {"scheduler": lr_sched, "interval": "step"}} + + def get_batch_size(self, batch: TrainingBatch) -> int: + return len(batch[0]) + + @classmethod + def _load(cls, path, map_location, **submodules): + d = torch.load(path, map_location, weights_only=False) + + try: + hparams = d["hyper_parameters"] + state_dict = d["state_dict"] + except KeyError: + raise KeyError(f"Could not find hyper parameters and/or state dict in {path}.") + + if hparams["metrics"] is not None: + hparams["metrics"] = [ + cls._rebuild_metric(metric) + if not hasattr(metric, "_defaults") + or (not torch.cuda.is_available() and metric.device.type != "cpu") + else metric + for metric in hparams["metrics"] + ] + + if hparams["predictor"]["criterion"] is not None: + metric = hparams["predictor"]["criterion"] + if not hasattr(metric, "_defaults") or ( + not torch.cuda.is_available() and metric.device.type != "cpu" + ): + hparams["predictor"]["criterion"] = cls._rebuild_metric(metric) + + submodules |= { + key: hparams[key].pop("cls")(**hparams[key]) + for key in ("message_passing", "agg", "predictor") + if key not in submodules + } + + return submodules, state_dict, hparams + + @classmethod + def _add_metric_task_weights_to_state_dict(cls, state_dict, hparams): + if "metrics.0.task_weights" not in state_dict: + metrics = hparams["metrics"] + n_metrics = len(metrics) if metrics is not None else 1 + for i_metric in range(n_metrics): + state_dict[f"metrics.{i_metric}.task_weights"] = torch.tensor([[1.0]]) + state_dict[f"metrics.{i_metric + 1}.task_weights"] = state_dict[ + "predictor.criterion.task_weights" + ] + return state_dict + + @classmethod + def _rebuild_metric(cls, metric): + return Factory.build(metric.__class__, task_weights=metric.task_weights, **metric.__dict__) + + @classmethod + def load_from_checkpoint( + cls, checkpoint_path, map_location=None, hparams_file=None, strict=True, **kwargs + ) -> MPNN: + submodules = { + k: v for k, v in kwargs.items() if k in ["message_passing", "agg", "predictor"] + } + submodules, state_dict, hparams = cls._load(checkpoint_path, map_location, **submodules) + kwargs.update(submodules) + + state_dict = cls._add_metric_task_weights_to_state_dict(state_dict, hparams) + d = torch.load(checkpoint_path, map_location, weights_only=False) + d["state_dict"] = state_dict + d["hyper_parameters"] = hparams + buffer = io.BytesIO() + torch.save(d, buffer) + buffer.seek(0) + + return super().load_from_checkpoint(buffer, map_location, hparams_file, strict, **kwargs) + + @classmethod + def load_from_file(cls, model_path, map_location=None, strict=True, **submodules) -> MPNN: + submodules, state_dict, hparams = cls._load(model_path, map_location, **submodules) + hparams.update(submodules) + + state_dict = cls._add_metric_task_weights_to_state_dict(state_dict, hparams) + + model = cls(**hparams) + model.load_state_dict(state_dict, strict=strict) + + return model \ No newline at end of file diff --git a/chemprop-updated/chemprop/models/multi.py b/chemprop-updated/chemprop/models/multi.py new file mode 100644 index 0000000000000000000000000000000000000000..930b815b1e8f8688101ab8ce14697f54c41b3e0e --- /dev/null +++ b/chemprop-updated/chemprop/models/multi.py @@ -0,0 +1,101 @@ 
+from typing import Iterable + +import torch +from torch import Tensor + +from chemprop.data import BatchMolGraph, MulticomponentTrainingBatch +from chemprop.models.model import MPNN +from chemprop.nn import Aggregation, MulticomponentMessagePassing, Predictor +from chemprop.nn.metrics import ChempropMetric +from chemprop.nn.transforms import ScaleTransform + + +class MulticomponentMPNN(MPNN): + def __init__( + self, + message_passing: MulticomponentMessagePassing, + agg: Aggregation, + predictor: Predictor, + batch_norm: bool = False, + metrics: Iterable[ChempropMetric] | None = None, + warmup_epochs: int = 2, + init_lr: float = 1e-4, + max_lr: float = 1e-3, + final_lr: float = 1e-4, + X_d_transform: ScaleTransform | None = None, + ): + super().__init__( + message_passing, + agg, + predictor, + batch_norm, + metrics, + warmup_epochs, + init_lr, + max_lr, + final_lr, + X_d_transform, + ) + self.message_passing: MulticomponentMessagePassing + + def fingerprint( + self, + bmgs: Iterable[BatchMolGraph], + V_ds: Iterable[Tensor | None], + X_d: Tensor | None = None, + ) -> Tensor: + H_vs: list[Tensor] = self.message_passing(bmgs, V_ds) + Hs = [self.agg(H_v, bmg.batch) for H_v, bmg in zip(H_vs, bmgs)] + H = torch.cat(Hs, 1) + H = self.bn(H) + + return H if X_d is None else torch.cat((H, self.X_d_transform(X_d)), 1) + + def on_validation_model_eval(self) -> None: + self.eval() + for block in self.message_passing.blocks: + block.V_d_transform.train() + block.graph_transform.train() + self.X_d_transform.train() + self.predictor.output_transform.train() + + def get_batch_size(self, batch: MulticomponentTrainingBatch) -> int: + return len(batch[0][0]) + + @classmethod + def _load(cls, path, map_location, **submodules): + d = torch.load(path, map_location, weights_only=False) + + try: + hparams = d["hyper_parameters"] + state_dict = d["state_dict"] + except KeyError: + raise KeyError(f"Could not find hyper parameters and/or state dict in {path}.") + + if hparams["metrics"] is not None: + hparams["metrics"] = [ + cls._rebuild_metric(metric) + if not hasattr(metric, "_defaults") + or (not torch.cuda.is_available() and metric.device.type != "cpu") + else metric + for metric in hparams["metrics"] + ] + + if hparams["predictor"]["criterion"] is not None: + metric = hparams["predictor"]["criterion"] + if not hasattr(metric, "_defaults") or ( + not torch.cuda.is_available() and metric.device.type != "cpu" + ): + hparams["predictor"]["criterion"] = cls._rebuild_metric(metric) + + hparams["message_passing"]["blocks"] = [ + block_hparams.pop("cls")(**block_hparams) + for block_hparams in hparams["message_passing"]["blocks"] + ] + submodules |= { + key: hparams[key].pop("cls")(**hparams[key]) + for key in ("message_passing", "agg", "predictor") + if key not in submodules + } + + return submodules, state_dict, hparams diff --git a/chemprop-updated/chemprop/models/utils.py b/chemprop-updated/chemprop/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cf0d06b5fb8e7841856f3a143e10e16701d62783 --- /dev/null +++ b/chemprop-updated/chemprop/models/utils.py @@ -0,0 +1,32 @@ +from os import PathLike + +import torch + +from chemprop.models.model import MPNN +from chemprop.models.multi import MulticomponentMPNN + + +def save_model(path: PathLike, model: MPNN, output_columns: list[str] = None) -> None: + torch.save( + { + "hyper_parameters": model.hparams, + "state_dict": model.state_dict(), + "output_columns": output_columns, + }, + path, + ) + + +def load_model(path: PathLike, 
multicomponent: bool) -> MPNN: + if multicomponent: + model = MulticomponentMPNN.load_from_file(path, map_location=torch.device("cpu")) + else: + model = MPNN.load_from_file(path, map_location=torch.device("cpu")) + + return model + + +def load_output_columns(path: PathLike) -> list[str] | None: + model_file = torch.load(path, map_location=torch.device("cpu"), weights_only=False) + + return model_file.get("output_columns") diff --git a/chemprop-updated/chemprop/nn/__init__.py b/chemprop-updated/chemprop/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3e8680ede568cfe936fefe16626f3bbdf39a5725 --- /dev/null +++ b/chemprop-updated/chemprop/nn/__init__.py @@ -0,0 +1,127 @@ +from .agg import ( + Aggregation, + AggregationRegistry, + AttentiveAggregation, + MeanAggregation, + NormAggregation, + SumAggregation, +) +from .message_passing import ( + AtomMessagePassing, + BondMessagePassing, + MessagePassing, + MulticomponentMessagePassing, +) +from .metrics import ( + MAE, + MSE, + RMSE, + SID, + BCELoss, + BinaryAccuracy, + BinaryAUPRC, + BinaryAUROC, + BinaryF1Score, + BinaryMCCLoss, + BinaryMCCMetric, + BoundedMAE, + BoundedMixin, + BoundedMSE, + BoundedRMSE, + ChempropMetric, + ClassificationMixin, + CrossEntropyLoss, + DirichletLoss, + EvidentialLoss, + LossFunctionRegistry, + MetricRegistry, + MulticlassMCCLoss, + MulticlassMCCMetric, + MVELoss, + QuantileLoss, + R2Score, + Wasserstein, +) +from .predictors import ( + BinaryClassificationFFN, + BinaryClassificationFFNBase, + BinaryDirichletFFN, + EvidentialFFN, + MulticlassClassificationFFN, + MulticlassDirichletFFN, + MveFFN, + Predictor, + PredictorRegistry, + QuantileFFN, + RegressionFFN, + SpectralFFN, +) +from .transforms import GraphTransform, ScaleTransform, UnscaleTransform +from .utils import Activation + +__all__ = [ + "Aggregation", + "AggregationRegistry", + "MeanAggregation", + "SumAggregation", + "NormAggregation", + "AttentiveAggregation", + "ChempropMetric", + "ClassificationMixin", + "LossFunctionRegistry", + "MetricRegistry", + "MSE", + "MAE", + "RMSE", + "BoundedMixin", + "BoundedMSE", + "BoundedMAE", + "BoundedRMSE", + "BinaryAccuracy", + "BinaryAUPRC", + "BinaryAUROC", + "BinaryF1Score", + "BinaryMCCMetric", + "BoundedMAE", + "BoundedMSE", + "BoundedRMSE", + "MetricRegistry", + "MulticlassMCCMetric", + "R2Score", + "MVELoss", + "EvidentialLoss", + "BCELoss", + "CrossEntropyLoss", + "BinaryMCCLoss", + "BinaryMCCMetric", + "MulticlassMCCLoss", + "MulticlassMCCMetric", + "BinaryAUROC", + "BinaryAUPRC", + "BinaryAccuracy", + "BinaryF1Score", + "MulticlassDirichletLoss", + "SID", + "Wasserstein", + "QuantileLoss", + "MessagePassing", + "AtomMessagePassing", + "BondMessagePassing", + "MulticomponentMessagePassing", + "Predictor", + "PredictorRegistry", + "QuantileFFN", + "RegressionFFN", + "MveFFN", + "DirichletLoss", + "EvidentialFFN", + "BinaryClassificationFFNBase", + "BinaryClassificationFFN", + "BinaryDirichletFFN", + "MulticlassClassificationFFN", + "SpectralFFN", + "Activation", + "GraphTransform", + "ScaleTransform", + "UnscaleTransform", +] diff --git a/chemprop-updated/chemprop/nn/agg.py b/chemprop-updated/chemprop/nn/agg.py new file mode 100644 index 0000000000000000000000000000000000000000..ed921b41d41f68534931a93c56552f31bd792d34 --- /dev/null +++ b/chemprop-updated/chemprop/nn/agg.py @@ -0,0 +1,133 @@ +from abc import abstractmethod + +import torch +from torch import Tensor, nn + +from chemprop.nn.hparams import HasHParams +from chemprop.utils import ClassRegistry + +__all__ = [ 
+ "Aggregation", + "AggregationRegistry", + "MeanAggregation", + "SumAggregation", + "NormAggregation", + "AttentiveAggregation", +] + + +class Aggregation(nn.Module, HasHParams): + """An :class:`Aggregation` aggregates the node-level representations of a batch of graphs into + a batch of graph-level representations + + .. note:: + this class is abstract and cannot be instantiated. + + See also + -------- + :class:`~chemprop.v2.models.modules.agg.MeanAggregation` + :class:`~chemprop.v2.models.modules.agg.SumAggregation` + :class:`~chemprop.v2.models.modules.agg.NormAggregation` + """ + + def __init__(self, dim: int = 0, *args, **kwargs): + super().__init__() + + self.dim = dim + self.hparams = {"dim": dim, "cls": self.__class__} + + @abstractmethod + def forward(self, H: Tensor, batch: Tensor) -> Tensor: + """Aggregate the graph-level representations of a batch of graphs into their respective + global representations + + NOTE: it is possible for a graph to have 0 nodes. In this case, the representation will be + a zero vector of length `d` in the final output. + + Parameters + ---------- + H : Tensor + a tensor of shape ``V x d`` containing the batched node-level representations of ``b`` + graphs + batch : Tensor + a tensor of shape ``V`` containing the index of the graph a given vertex corresponds to + + Returns + ------- + Tensor + a tensor of shape ``b x d`` containing the graph-level representations + """ + + +AggregationRegistry = ClassRegistry[Aggregation]() + + +@AggregationRegistry.register("mean") +class MeanAggregation(Aggregation): + r"""Average the graph-level representation: + + .. math:: + \mathbf h = \frac{1}{|V|} \sum_{v \in V} \mathbf h_v + """ + + def forward(self, H: Tensor, batch: Tensor) -> Tensor: + index_torch = batch.unsqueeze(1).repeat(1, H.shape[1]) + dim_size = batch.max().int() + 1 + return torch.zeros(dim_size, H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_( + self.dim, index_torch, H, reduce="mean", include_self=False + ) + + +@AggregationRegistry.register("sum") +class SumAggregation(Aggregation): + r"""Sum the graph-level representation: + + .. math:: + \mathbf h = \sum_{v \in V} \mathbf h_v + + """ + + def forward(self, H: Tensor, batch: Tensor) -> Tensor: + index_torch = batch.unsqueeze(1).repeat(1, H.shape[1]) + dim_size = batch.max().int() + 1 + return torch.zeros(dim_size, H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_( + self.dim, index_torch, H, reduce="sum", include_self=False + ) + + +@AggregationRegistry.register("norm") +class NormAggregation(SumAggregation): + r"""Sum the graph-level representation and divide by a normalization constant: + + .. 
math:: + \mathbf h = \frac{1}{c} \sum_{v \in V} \mathbf h_v + """ + + def __init__(self, dim: int = 0, *args, norm: float = 100.0, **kwargs): + super().__init__(dim, **kwargs) + + self.norm = norm + self.hparams["norm"] = norm + + def forward(self, H: Tensor, batch: Tensor) -> Tensor: + return super().forward(H, batch) / self.norm + + +class AttentiveAggregation(Aggregation): + def __init__(self, dim: int = 0, *args, output_size: int, **kwargs): + super().__init__(dim, *args, **kwargs) + + self.hparams["output_size"] = output_size + self.W = nn.Linear(output_size, 1) + + def forward(self, H: Tensor, batch: Tensor) -> Tensor: + dim_size = batch.max().int() + 1 + attention_logits = self.W(H).exp() + Z = torch.zeros(dim_size, 1, dtype=H.dtype, device=H.device).scatter_reduce_( + self.dim, batch.unsqueeze(1), attention_logits, reduce="sum", include_self=False + ) + alphas = attention_logits / Z[batch] + index_torch = batch.unsqueeze(1).repeat(1, H.shape[1]) + return torch.zeros(dim_size, H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_( + self.dim, index_torch, alphas * H, reduce="sum", include_self=False + ) diff --git a/chemprop-updated/chemprop/nn/ffn.py b/chemprop-updated/chemprop/nn/ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..f5a02fe92391adeca9c88ec371296951a0132928 --- /dev/null +++ b/chemprop-updated/chemprop/nn/ffn.py @@ -0,0 +1,63 @@ +from abc import abstractmethod + +from torch import Tensor, nn + +from chemprop.nn.utils import get_activation_function + + +class FFN(nn.Module): + r"""A :class:`FFN` is a differentiable function + :math:`f_\theta : \mathbb R^i \mapsto \mathbb R^o`""" + + input_dim: int + output_dim: int + + @abstractmethod + def forward(self, X: Tensor) -> Tensor: + pass + + +class MLP(nn.Sequential, FFN): + r"""An :class:`MLP` is an FFN that implements the following function: + + .. math:: + \mathbf h_0 &= \mathbf W_0 \mathbf x \,+ \mathbf b_{0} \\ + \mathbf h_l &= \mathbf W_l \left( \mathtt{dropout} \left( \sigma ( \,\mathbf h_{l-1}\, ) \right) \right) + \mathbf b_l\\ + + where :math:`\mathbf x` is the input tensor, :math:`\mathbf W_l` and :math:`\mathbf b_l` + are the learned weight matrix and bias, respectively, of the :math:`l`-th layer, + :math:`\mathbf h_l` is the hidden representation after layer :math:`l`, and :math:`\sigma` + is the activation function. 
+ """ + + @classmethod + def build( + cls, + input_dim: int, + output_dim: int, + hidden_dim: int = 300, + n_layers: int = 1, + dropout: float = 0.0, + activation: str = "relu", + ): + dropout = nn.Dropout(dropout) + act = get_activation_function(activation) + dims = [input_dim] + [hidden_dim] * n_layers + [output_dim] + blocks = [nn.Sequential(nn.Linear(dims[0], dims[1]))] + if len(dims) > 2: + blocks.extend( + [ + nn.Sequential(act, dropout, nn.Linear(d1, d2)) + for d1, d2 in zip(dims[1:-1], dims[2:]) + ] + ) + + return cls(*blocks) + + @property + def input_dim(self) -> int: + return self[0][-1].in_features + + @property + def output_dim(self) -> int: + return self[-1][-1].out_features diff --git a/chemprop-updated/chemprop/nn/hparams.py b/chemprop-updated/chemprop/nn/hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..ffa17ab80c16bbae47b35a18d8fe9f3eb66ee590 --- /dev/null +++ b/chemprop-updated/chemprop/nn/hparams.py @@ -0,0 +1,38 @@ +from typing import Protocol, Type, TypedDict + + +class HParamsDict(TypedDict): + """A dictionary containing a module's class and it's hyperparameters + + Using this type should essentially allow for initializing a module via:: + + module = hparams.pop('cls')(**hparams) + """ + + cls: Type + + +class HasHParams(Protocol): + """:class:`HasHParams` is a protocol for clases which possess an :attr:`hparams` attribute which is a dictionary containing the object's class and arguments required to initialize it. + + That is, any object which implements :class:`HasHParams` should be able to be initialized via:: + + class Foo(HasHParams): + def __init__(self, *args, **kwargs): + ... + + foo1 = Foo(...) + foo1_cls = foo1.hparams['cls'] + foo1_kwargs = {k: v for k, v in foo1.hparams.items() if k != "cls"} + foo2 = foo1_cls(**foo1_kwargs) + # code to compare foo1 and foo2 goes here and they should be equal + """ + + hparams: HParamsDict + + +def from_hparams(hparams: HParamsDict): + cls = hparams["cls"] + kwargs = {k: v for k, v in hparams.items() if k != "cls"} + + return cls(**kwargs) diff --git a/chemprop-updated/chemprop/nn/message_passing/__init__.py b/chemprop-updated/chemprop/nn/message_passing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..97078653c6524d5645d4862a8ef683a2d8eb457e --- /dev/null +++ b/chemprop-updated/chemprop/nn/message_passing/__init__.py @@ -0,0 +1,10 @@ +from .base import AtomMessagePassing, BondMessagePassing +from .multi import MulticomponentMessagePassing +from .proto import MessagePassing + +__all__ = [ + "MessagePassing", + "AtomMessagePassing", + "BondMessagePassing", + "MulticomponentMessagePassing", +] diff --git a/chemprop-updated/chemprop/nn/message_passing/base.py b/chemprop-updated/chemprop/nn/message_passing/base.py new file mode 100644 index 0000000000000000000000000000000000000000..8bb14b0f51c97ba083e26b402db7a0c77023a8db --- /dev/null +++ b/chemprop-updated/chemprop/nn/message_passing/base.py @@ -0,0 +1,319 @@ +from abc import abstractmethod + +from lightning.pytorch.core.mixins import HyperparametersMixin +import torch +from torch import Tensor, nn + +from chemprop.conf import DEFAULT_ATOM_FDIM, DEFAULT_BOND_FDIM, DEFAULT_HIDDEN_DIM +from chemprop.data import BatchMolGraph +from chemprop.exceptions import InvalidShapeError +from chemprop.nn.message_passing.proto import MessagePassing +from chemprop.nn.transforms import GraphTransform, ScaleTransform +from chemprop.nn.utils import Activation, get_activation_function + + +class _MessagePassingBase(MessagePassing, 
HyperparametersMixin): + """The base message-passing block for atom- and bond-based message-passing schemes + + NOTE: this class is an abstract base class and cannot be instantiated + + Parameters + ---------- + d_v : int, default=DEFAULT_ATOM_FDIM + the feature dimension of the vertices + d_e : int, default=DEFAULT_BOND_FDIM + the feature dimension of the edges + d_h : int, default=DEFAULT_HIDDEN_DIM + the hidden dimension during message passing + bias : bool, defuault=False + if `True`, add a bias term to the learned weight matrices + depth : int, default=3 + the number of message passing iterations + undirected : bool, default=False + if `True`, pass messages on undirected edges + dropout : float, default=0.0 + the dropout probability + activation : str, default="relu" + the activation function to use + d_vd : int | None, default=None + the dimension of additional vertex descriptors that will be concatenated to the hidden features before readout + + See also + -------- + * :class:`AtomMessagePassing` + + * :class:`BondMessagePassing` + """ + + def __init__( + self, + d_v: int = DEFAULT_ATOM_FDIM, + d_e: int = DEFAULT_BOND_FDIM, + d_h: int = DEFAULT_HIDDEN_DIM, + bias: bool = False, + depth: int = 3, + dropout: float = 0.0, + activation: str | Activation = Activation.RELU, + undirected: bool = False, + d_vd: int | None = None, + V_d_transform: ScaleTransform | None = None, + graph_transform: GraphTransform | None = None, + # layers_per_message: int = 1, + ): + super().__init__() + # manually add V_d_transform and graph_transform to hparams to suppress lightning's warning + # about double saving their state_dict values. + self.save_hyperparameters(ignore=["V_d_transform", "graph_transform"]) + self.hparams["V_d_transform"] = V_d_transform + self.hparams["graph_transform"] = graph_transform + self.hparams["cls"] = self.__class__ + + self.W_i, self.W_h, self.W_o, self.W_d = self.setup(d_v, d_e, d_h, d_vd, bias) + self.depth = depth + self.undirected = undirected + self.dropout = nn.Dropout(dropout) + self.tau = get_activation_function(activation) + self.V_d_transform = V_d_transform if V_d_transform is not None else nn.Identity() + self.graph_transform = graph_transform if graph_transform is not None else nn.Identity() + + @property + def output_dim(self) -> int: + return self.W_d.out_features if self.W_d is not None else self.W_o.out_features + + @abstractmethod + def setup( + self, + d_v: int = DEFAULT_ATOM_FDIM, + d_e: int = DEFAULT_BOND_FDIM, + d_h: int = DEFAULT_HIDDEN_DIM, + d_vd: int | None = None, + bias: bool = False, + ) -> tuple[nn.Module, nn.Module, nn.Module, nn.Module | None]: + """setup the weight matrices used in the message passing update functions + + Parameters + ---------- + d_v : int + the vertex feature dimension + d_e : int + the edge feature dimension + d_h : int, default=300 + the hidden dimension during message passing + d_vd : int | None, default=None + the dimension of additional vertex descriptors that will be concatenated to the hidden + features before readout, if any + bias: bool, default=False + whether to add a learned bias to the matrices + + Returns + ------- + W_i, W_h, W_o, W_d : tuple[nn.Module, nn.Module, nn.Module, nn.Module | None] + the input, hidden, output, and descriptor weight matrices, respectively, used in the + message passing update functions. 
The descriptor weight matrix is `None` if no vertex + dimension is supplied + """ + + @abstractmethod + def initialize(self, bmg: BatchMolGraph) -> Tensor: + """initialize the message passing scheme by calculating initial matrix of hidden features""" + + @abstractmethod + def message(self, H_t: Tensor, bmg: BatchMolGraph): + """Calculate the message matrix""" + + def update(self, M_t, H_0): + """Calcualte the updated hidden for each edge""" + H_t = self.W_h(M_t) + H_t = self.tau(H_0 + H_t) + H_t = self.dropout(H_t) + + return H_t + + def finalize(self, M: Tensor, V: Tensor, V_d: Tensor | None) -> Tensor: + r"""Finalize message passing by (1) concatenating the final message ``M`` and the original + vertex features ``V`` and (2) if provided, further concatenating additional vertex + descriptors ``V_d``. + + This function implements the following operation: + + .. math:: + H &= \mathtt{dropout} \left( \tau(\mathbf{W}_o(V \mathbin\Vert M)) \right) \\ + H &= \mathtt{dropout} \left( \tau(\mathbf{W}_d(H \mathbin\Vert V_d)) \right), + + where :math:`\tau` is the activation function, :math:`\Vert` is the concatenation operator, + :math:`\mathbf{W}_o` and :math:`\mathbf{W}_d` are learned weight matrices, :math:`M` is + the message matrix, :math:`V` is the original vertex feature matrix, and :math:`V_d` is an + optional vertex descriptor matrix. + + Parameters + ---------- + M : Tensor + a tensor of shape ``V x d_h`` containing the message vector of each vertex + V : Tensor + a tensor of shape ``V x d_v`` containing the original vertex features + V_d : Tensor | None + an optional tensor of shape ``V x d_vd`` containing additional vertex descriptors + + Returns + ------- + Tensor + a tensor of shape ``V x (d_h + d_v [+ d_vd])`` containing the final hidden + representations + + Raises + ------ + InvalidShapeError + if ``V_d`` is not of shape ``b x d_vd``, where ``b`` is the batch size and ``d_vd`` is + the vertex descriptor dimension + """ + H = self.W_o(torch.cat((V, M), dim=1)) # V x d_o + H = self.tau(H) + H = self.dropout(H) + + if V_d is not None: + V_d = self.V_d_transform(V_d) + try: + H = self.W_d(torch.cat((H, V_d), dim=1)) # V x (d_o + d_vd) + H = self.dropout(H) + except RuntimeError: + raise InvalidShapeError( + "V_d", V_d.shape, [len(H), self.W_d.in_features - self.W_o.out_features] + ) + + return H + + def forward(self, bmg: BatchMolGraph, V_d: Tensor | None = None) -> Tensor: + """Encode a batch of molecular graphs. + + Parameters + ---------- + bmg: BatchMolGraph + a batch of :class:`BatchMolGraph`s to encode + V_d : Tensor | None, default=None + an optional tensor of shape ``V x d_vd`` containing additional descriptors for each atom + in the batch. These will be concatenated to the learned atomic descriptors and + transformed before the readout phase. 
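# Shape sketch (illustrative only) of the `finalize` step documented above: the message
# matrix M (V x d_h) is concatenated with the raw atom features V (V x d_v) and,
# optionally, extra atom descriptors V_d (V x d_vd). All dimensions are assumptions.
import torch

n_atoms, d_v, d_h, d_vd = 5, 72, 300, 10
V, M, V_d = torch.randn(n_atoms, d_v), torch.randn(n_atoms, d_h), torch.randn(n_atoms, d_vd)
W_o = torch.nn.Linear(d_v + d_h, d_h)
W_d = torch.nn.Linear(d_h + d_vd, d_h + d_vd)

H = torch.relu(W_o(torch.cat((V, M), dim=1)))  # V x d_h
H = W_d(torch.cat((H, V_d), dim=1))            # V x (d_h + d_vd); dropout omitted here
assert H.shape == (n_atoms, d_h + d_vd)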
+ + Returns + ------- + Tensor + a tensor of shape ``V x d_h`` or ``V x (d_h + d_vd)`` containing the encoding of each + molecule in the batch, depending on whether additional atom descriptors were provided + """ + bmg = self.graph_transform(bmg) + H_0 = self.initialize(bmg) + + H = self.tau(H_0) + for _ in range(1, self.depth): + if self.undirected: + H = (H + H[bmg.rev_edge_index]) / 2 + + M = self.message(H, bmg) + H = self.update(M, H_0) + + index_torch = bmg.edge_index[1].unsqueeze(1).repeat(1, H.shape[1]) + M = torch.zeros(len(bmg.V), H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_( + 0, index_torch, H, reduce="sum", include_self=False + ) + return self.finalize(M, bmg.V, V_d) + + +class BondMessagePassing(_MessagePassingBase): + r"""A :class:`BondMessagePassing` encodes a batch of molecular graphs by passing messages along + directed bonds. + + It implements the following operation: + + .. math:: + + h_{vw}^{(0)} &= \tau \left( \mathbf W_i(e_{vw}) \right) \\ + m_{vw}^{(t)} &= \sum_{u \in \mathcal N(v)\setminus w} h_{uv}^{(t-1)} \\ + h_{vw}^{(t)} &= \tau \left(h_v^{(0)} + \mathbf W_h m_{vw}^{(t-1)} \right) \\ + m_v^{(T)} &= \sum_{w \in \mathcal N(v)} h_w^{(T-1)} \\ + h_v^{(T)} &= \tau \left (\mathbf W_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right), + + where :math:`\tau` is the activation function; :math:`\mathbf W_i`, :math:`\mathbf W_h`, and + :math:`\mathbf W_o` are learned weight matrices; :math:`e_{vw}` is the feature vector of the + bond between atoms :math:`v` and :math:`w`; :math:`x_v` is the feature vector of atom :math:`v`; + :math:`h_{vw}^{(t)}` is the hidden representation of the bond :math:`v \rightarrow w` at + iteration :math:`t`; :math:`m_{vw}^{(t)}` is the message received by the bond :math:`v + \to w` at iteration :math:`t`; and :math:`t \in \{1, \dots, T-1\}` is the number of + message passing iterations. + """ + + def setup( + self, + d_v: int = DEFAULT_ATOM_FDIM, + d_e: int = DEFAULT_BOND_FDIM, + d_h: int = DEFAULT_HIDDEN_DIM, + d_vd: int | None = None, + bias: bool = False, + ): + W_i = nn.Linear(d_v + d_e, d_h, bias) + W_h = nn.Linear(d_h, d_h, bias) + W_o = nn.Linear(d_v + d_h, d_h) + # initialize W_d only when d_vd is neither 0 nor None + W_d = nn.Linear(d_h + d_vd, d_h + d_vd) if d_vd else None + + return W_i, W_h, W_o, W_d + + def initialize(self, bmg: BatchMolGraph) -> Tensor: + return self.W_i(torch.cat([bmg.V[bmg.edge_index[0]], bmg.E], dim=1)) + + def message(self, H: Tensor, bmg: BatchMolGraph) -> Tensor: + index_torch = bmg.edge_index[1].unsqueeze(1).repeat(1, H.shape[1]) + M_all = torch.zeros(len(bmg.V), H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_( + 0, index_torch, H, reduce="sum", include_self=False + )[bmg.edge_index[0]] + M_rev = H[bmg.rev_edge_index] + + return M_all - M_rev + + +class AtomMessagePassing(_MessagePassingBase): + r"""A :class:`AtomMessagePassing` encodes a batch of molecular graphs by passing messages along + atoms. + + It implements the following operation: + + .. 
math:: + + h_v^{(0)} &= \tau \left( \mathbf{W}_i(x_v) \right) \\ + m_v^{(t)} &= \sum_{u \in \mathcal{N}(v)} h_u^{(t-1)} \mathbin\Vert e_{uv} \\ + h_v^{(t)} &= \tau\left(h_v^{(0)} + \mathbf{W}_h m_v^{(t-1)}\right) \\ + m_v^{(T)} &= \sum_{w \in \mathcal{N}(v)} h_w^{(T-1)} \\ + h_v^{(T)} &= \tau \left (\mathbf{W}_o \left( x_v \mathbin\Vert m_{v}^{(T)} \right) \right), + + where :math:`\tau` is the activation function; :math:`\mathbf{W}_i`, :math:`\mathbf{W}_h`, and + :math:`\mathbf{W}_o` are learned weight matrices; :math:`e_{vw}` is the feature vector of the + bond between atoms :math:`v` and :math:`w`; :math:`x_v` is the feature vector of atom :math:`v`; + :math:`h_v^{(t)}` is the hidden representation of atom :math:`v` at iteration :math:`t`; + :math:`m_v^{(t)}` is the message received by atom :math:`v` at iteration :math:`t`; and + :math:`t \in \{1, \dots, T\}` is the number of message passing iterations. + """ + + def setup( + self, + d_v: int = DEFAULT_ATOM_FDIM, + d_e: int = DEFAULT_BOND_FDIM, + d_h: int = DEFAULT_HIDDEN_DIM, + d_vd: int | None = None, + bias: bool = False, + ): + W_i = nn.Linear(d_v, d_h, bias) + W_h = nn.Linear(d_e + d_h, d_h, bias) + W_o = nn.Linear(d_v + d_h, d_h) + # initialize W_d only when d_vd is neither 0 nor None + W_d = nn.Linear(d_h + d_vd, d_h + d_vd) if d_vd else None + + return W_i, W_h, W_o, W_d + + def initialize(self, bmg: BatchMolGraph) -> Tensor: + return self.W_i(bmg.V[bmg.edge_index[0]]) + + def message(self, H: Tensor, bmg: BatchMolGraph): + H = torch.cat((H, bmg.E), dim=1) + index_torch = bmg.edge_index[1].unsqueeze(1).repeat(1, H.shape[1]) + return torch.zeros(len(bmg.V), H.shape[1], dtype=H.dtype, device=H.device).scatter_reduce_( + 0, index_torch, H, reduce="sum", include_self=False + )[bmg.edge_index[0]] diff --git a/chemprop-updated/chemprop/nn/message_passing/multi.py b/chemprop-updated/chemprop/nn/message_passing/multi.py new file mode 100644 index 0000000000000000000000000000000000000000..98a9cb84c55dbba5c8fa50ac56cac0c30a2171b8 --- /dev/null +++ b/chemprop-updated/chemprop/nn/message_passing/multi.py @@ -0,0 +1,80 @@ +import logging +from typing import Iterable, Sequence + +from torch import Tensor, nn + +from chemprop.data import BatchMolGraph +from chemprop.nn.hparams import HasHParams +from chemprop.nn.message_passing.proto import MessagePassing + +logger = logging.getLogger(__name__) + + +class MulticomponentMessagePassing(nn.Module, HasHParams): + """A `MulticomponentMessagePassing` performs message-passing on each individual input in a + multicomponent input then concatenates the representation of each input to construct a + global representation + + Parameters + ---------- + blocks : Sequence[MessagePassing] + the invidual message-passing blocks for each input + n_components : int + the number of components in each input + shared : bool, default=False + whether one block will be shared among all components in an input. If not, a separate + block will be learned for each component. + """ + + def __init__(self, blocks: Sequence[MessagePassing], n_components: int, shared: bool = False): + super().__init__() + self.hparams = { + "cls": self.__class__, + "blocks": [block.hparams for block in blocks], + "n_components": n_components, + "shared": shared, + } + + if len(blocks) == 0: + raise ValueError("arg 'blocks' was empty!") + if shared and len(blocks) > 1: + logger.warning( + "More than 1 block was supplied but 'shared' was True! Using only the 0th block..." 
+ ) + elif not shared and len(blocks) != n_components: + raise ValueError( + "arg 'n_components' must be equal to `len(blocks)` if 'shared' is False! " + f"got: {n_components} and {len(blocks)}, respectively." + ) + + self.n_components = n_components + self.shared = shared + self.blocks = nn.ModuleList([blocks[0]] * self.n_components if shared else blocks) + + def __len__(self) -> int: + return len(self.blocks) + + @property + def output_dim(self) -> int: + d_o = sum(block.output_dim for block in self.blocks) + + return d_o + + def forward(self, bmgs: Iterable[BatchMolGraph], V_ds: Iterable[Tensor | None]) -> list[Tensor]: + """Encode the multicomponent inputs + + Parameters + ---------- + bmgs : Iterable[BatchMolGraph] + V_ds : Iterable[Tensor | None] + + Returns + ------- + list[Tensor] + a list of tensors of shape `V x d_i` containing the respective encodings of the `i`\th + component, where `d_i` is the output dimension of the `i`\th encoder + """ + if V_ds is None: + return [block(bmg) for block, bmg in zip(self.blocks, bmgs)] + else: + return [block(bmg, V_d) for block, bmg, V_d in zip(self.blocks, bmgs, V_ds)] diff --git a/chemprop-updated/chemprop/nn/message_passing/proto.py b/chemprop-updated/chemprop/nn/message_passing/proto.py new file mode 100644 index 0000000000000000000000000000000000000000..f00c8a36002c36485da6fbdb08d6137c6d954765 --- /dev/null +++ b/chemprop-updated/chemprop/nn/message_passing/proto.py @@ -0,0 +1,35 @@ +from abc import abstractmethod + +from torch import Tensor, nn + +from chemprop.data import BatchMolGraph +from chemprop.nn.hparams import HasHParams + + +class MessagePassing(nn.Module, HasHParams): + """A :class:`MessagePassing` module encodes a batch of molecular graphs + using message passing to learn vertex-level hidden representations.""" + + input_dim: int + output_dim: int + + @abstractmethod + def forward(self, bmg: BatchMolGraph, V_d: Tensor | None = None) -> Tensor: + """Encode a batch of molecular graphs. + + Parameters + ---------- + bmg: BatchMolGraph + the batch of :class:`~chemprop.featurizers.molgraph.MolGraph`\s to encode + V_d : Tensor | None, default=None + an optional tensor of shape `V x d_vd` containing additional descriptors for each atom + in the batch. These will be concatenated to the learned atomic descriptors and + transformed before the readout phase. + + Returns + ------- + Tensor + a tensor of shape `V x d_h` or `V x (d_h + d_vd)` containing the hidden representation + of each vertex in the batch of graphs. 
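# Illustrative sketch (not part of this diff): building the multicomponent encoder
# defined above from two independent bond message-passing blocks.
from chemprop.nn.message_passing import BondMessagePassing, MulticomponentMessagePassing

blocks = [BondMessagePassing(), BondMessagePassing()]
mcmp = MulticomponentMessagePassing(blocks, n_components=2, shared=False)
# mcmp.output_dim is the sum of the per-block output dims (2 x d_h with the defaults)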
The feature dimension depends on whether + additional atom descriptors were provided + """ diff --git a/chemprop-updated/chemprop/nn/metrics.py b/chemprop-updated/chemprop/nn/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..cbbb346f4f9e61a11d97e857aee8c75de28ed44f --- /dev/null +++ b/chemprop-updated/chemprop/nn/metrics.py @@ -0,0 +1,567 @@ +from abc import abstractmethod + +from numpy.typing import ArrayLike +import torch +from torch import Tensor +from torch.nn import functional as F +import torchmetrics +from torchmetrics.utilities.compute import auc +from torchmetrics.utilities.data import dim_zero_cat + +from chemprop.utils.registry import ClassRegistry + +__all__ = [ + "ChempropMetric", + "LossFunctionRegistry", + "MetricRegistry", + "MSE", + "MAE", + "RMSE", + "BoundedMixin", + "BoundedMSE", + "BoundedMAE", + "BoundedRMSE", + "BinaryAccuracy", + "BinaryAUPRC", + "BinaryAUROC", + "BinaryF1Score", + "BinaryMCCMetric", + "BoundedMAE", + "BoundedMSE", + "BoundedRMSE", + "MetricRegistry", + "MulticlassMCCMetric", + "R2Score", + "MVELoss", + "EvidentialLoss", + "BCELoss", + "CrossEntropyLoss", + "BinaryMCCLoss", + "BinaryMCCMetric", + "MulticlassMCCLoss", + "MulticlassMCCMetric", + "ClassificationMixin", + "BinaryAUROC", + "BinaryAUPRC", + "BinaryAccuracy", + "BinaryF1Score", + "DirichletLoss", + "SID", + "Wasserstein", + "QuantileLoss", +] + + +class ChempropMetric(torchmetrics.Metric): + is_differentiable = True + higher_is_better = False + full_state_update = False + + def __init__(self, task_weights: ArrayLike = 1.0): + """ + Parameters + ---------- + task_weights : ArrayLike, default=1.0 + the per-task weights of shape `t` or `1 x t`. Defaults to all tasks having a weight of 1. + """ + super().__init__() + task_weights = torch.as_tensor(task_weights, dtype=torch.float).view(1, -1) + self.register_buffer("task_weights", task_weights) + + self.add_state("total_loss", default=torch.tensor(0.0), dist_reduce_fx="sum") + self.add_state("num_samples", default=torch.tensor(0), dist_reduce_fx="sum") + + def update( + self, + preds: Tensor, + targets: Tensor, + mask: Tensor | None = None, + weights: Tensor | None = None, + lt_mask: Tensor | None = None, + gt_mask: Tensor | None = None, + ) -> None: + """Calculate the mean loss function value given predicted and target values + + Parameters + ---------- + preds : Tensor + a tensor of shape `b x t x u` (regression), `b x t` (binary classification), or + `b x t x c` (multiclass classification) containing the predictions, where `b` is the + batch size, `t` is the number of tasks to predict, `u` is the number of + targets to predict for each task, and `c` is the number of classes. 
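# Illustrative sketch (not part of this diff): how `update`/`compute` accumulate a
# masked, task-weighted loss, shown with the MSE metric. All tensor values are made up.
import torch
from chemprop.nn.metrics import MSE

metric = MSE(task_weights=[1.0, 2.0])
preds = torch.tensor([[0.5, 1.0], [2.0, 0.0]])
targets = torch.tensor([[1.0, 1.0], [2.0, 1.0]])
mask = torch.tensor([[True, True], [True, False]])  # the last target is missing
metric.update(preds, targets, mask=mask, weights=torch.ones(2))
loss = metric.compute()  # sum of weighted squared errors / number of unmasked entries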
+ targets : Tensor + a float tensor of shape `b x t` containing the target values + mask : Tensor + a boolean tensor of shape `b x t` indicating whether the given prediction should be + included in the loss calculation + weights : Tensor + a tensor of shape `b` or `b x 1` containing the per-sample weight + lt_mask: Tensor + gt_mask: Tensor + """ + mask = torch.ones_like(targets, dtype=torch.bool) if mask is None else mask + weights = torch.ones_like(targets, dtype=torch.float) if weights is None else weights + lt_mask = torch.zeros_like(targets, dtype=torch.bool) if lt_mask is None else lt_mask + gt_mask = torch.zeros_like(targets, dtype=torch.bool) if gt_mask is None else gt_mask + + L = self._calc_unreduced_loss(preds, targets, mask, weights, lt_mask, gt_mask) + L = L * weights.view(-1, 1) * self.task_weights * mask + + self.total_loss += L.sum() + self.num_samples += mask.sum() + + def compute(self): + return self.total_loss / self.num_samples + + @abstractmethod + def _calc_unreduced_loss(self, preds, targets, mask, weights, lt_mask, gt_mask) -> Tensor: + """Calculate a tensor of shape `b x t` containing the unreduced loss values.""" + + def extra_repr(self) -> str: + return f"task_weights={self.task_weights.tolist()}" + + +LossFunctionRegistry = ClassRegistry[ChempropMetric]() +MetricRegistry = ClassRegistry[ChempropMetric]() + + +@LossFunctionRegistry.register("mse") +@MetricRegistry.register("mse") +class MSE(ChempropMetric): + def _calc_unreduced_loss(self, preds: Tensor, targets: Tensor, *args) -> Tensor: + return F.mse_loss(preds, targets, reduction="none") + + +@MetricRegistry.register("mae") +@LossFunctionRegistry.register("mae") +class MAE(ChempropMetric): + def _calc_unreduced_loss(self, preds, targets, *args) -> Tensor: + return (preds - targets).abs() + + +@LossFunctionRegistry.register("rmse") +@MetricRegistry.register("rmse") +class RMSE(MSE): + def compute(self): + return (self.total_loss / self.num_samples).sqrt() + + +class BoundedMixin: + def _calc_unreduced_loss(self, preds, targets, mask, weights, lt_mask, gt_mask) -> Tensor: + preds = torch.where((preds < targets) & lt_mask, targets, preds) + preds = torch.where((preds > targets) & gt_mask, targets, preds) + + return super()._calc_unreduced_loss(preds, targets, mask, weights) + + +@LossFunctionRegistry.register("bounded-mse") +@MetricRegistry.register("bounded-mse") +class BoundedMSE(BoundedMixin, MSE): + pass + + +@LossFunctionRegistry.register("bounded-mae") +@MetricRegistry.register("bounded-mae") +class BoundedMAE(BoundedMixin, MAE): + pass + + +@LossFunctionRegistry.register("bounded-rmse") +@MetricRegistry.register("bounded-rmse") +class BoundedRMSE(BoundedMixin, RMSE): + pass + + +@MetricRegistry.register("r2") +class R2Score(torchmetrics.R2Score): + def __init__(self, task_weights: ArrayLike = 1.0, **kwargs): + """ + Parameters + ---------- + task_weights : ArrayLike = 1.0 + .. important:: + Ignored. Maintained for compatibility with :class:`ChempropMetric` + """ + super().__init__() + task_weights = torch.as_tensor(task_weights, dtype=torch.float).view(1, -1) + self.register_buffer("task_weights", task_weights) + + def update(self, preds: Tensor, targets: Tensor, mask: Tensor, *args, **kwargs): + super().update(preds[mask], targets[mask]) + + +@LossFunctionRegistry.register("mve") +class MVELoss(ChempropMetric): + """Calculate the loss using Eq. 9 from [nix1994]_ + + References + ---------- + .. [nix1994] Nix, D. A.; Weigend, A. S. 
"Estimating the mean and variance of the target + probability distribution." Proceedings of 1994 IEEE International Conference on Neural + Networks, 1994 https://doi.org/10.1109/icnn.1994.374138 + """ + + def _calc_unreduced_loss(self, preds: Tensor, targets: Tensor, *args) -> Tensor: + mean, var = torch.unbind(preds, dim=-1) + + L_sos = (mean - targets) ** 2 / (2 * var) + L_kl = (2 * torch.pi * var).log() / 2 + + return L_sos + L_kl + + +@LossFunctionRegistry.register("evidential") +class EvidentialLoss(ChempropMetric): + """Calculate the loss using Eqs. 8, 9, and 10 from [amini2020]_. See also [soleimany2021]_. + + References + ---------- + .. [amini2020] Amini, A; Schwarting, W.; Soleimany, A.; Rus, D.; + "Deep Evidential Regression" Advances in Neural Information Processing Systems; 2020; Vol.33. + https://proceedings.neurips.cc/paper_files/paper/2020/file/aab085461de182608ee9f607f3f7d18f-Paper.pdf + .. [soleimany2021] Soleimany, A.P.; Amini, A.; Goldman, S.; Rus, D.; Bhatia, S.N.; Coley, C.W.; + "Evidential Deep Learning for Guided Molecular Property Prediction and Discovery." ACS + Cent. Sci. 2021, 7, 8, 1356-1367. https://doi.org/10.1021/acscentsci.1c00546 + """ + + def __init__(self, task_weights: ArrayLike = 1.0, v_kl: float = 0.2, eps: float = 1e-8): + super().__init__(task_weights) + self.v_kl = v_kl + self.eps = eps + + def _calc_unreduced_loss(self, preds: Tensor, targets: Tensor, *args) -> Tensor: + mean, v, alpha, beta = torch.unbind(preds, dim=-1) + + residuals = targets - mean + twoBlambda = 2 * beta * (1 + v) + + L_nll = ( + 0.5 * (torch.pi / v).log() + - alpha * twoBlambda.log() + + (alpha + 0.5) * torch.log(v * residuals**2 + twoBlambda) + + torch.lgamma(alpha) + - torch.lgamma(alpha + 0.5) + ) + + L_reg = (2 * v + alpha) * residuals.abs() + + return L_nll + self.v_kl * (L_reg - self.eps) + + def extra_repr(self) -> str: + parent_repr = super().extra_repr() + return parent_repr + f", v_kl={self.v_kl}, eps={self.eps}" + + +@LossFunctionRegistry.register("bce") +class BCELoss(ChempropMetric): + def _calc_unreduced_loss(self, preds: Tensor, targets: Tensor, *args) -> Tensor: + return F.binary_cross_entropy_with_logits(preds, targets, reduction="none") + + +@LossFunctionRegistry.register("ce") +class CrossEntropyLoss(ChempropMetric): + def _calc_unreduced_loss(self, preds: Tensor, targets: Tensor, *args) -> Tensor: + preds = preds.transpose(1, 2) + targets = targets.long() + + return F.cross_entropy(preds, targets, reduction="none") + + +@LossFunctionRegistry.register("binary-mcc") +class BinaryMCCLoss(ChempropMetric): + def __init__(self, task_weights: ArrayLike = 1.0): + """ + Parameters + ---------- + task_weights : ArrayLike, default=1.0 + the per-task weights of shape `t` or `1 x t`. Defaults to all tasks having a weight of 1. 
+ """ + super().__init__(task_weights) + + self.add_state("TP", default=[], dist_reduce_fx="cat") + self.add_state("FP", default=[], dist_reduce_fx="cat") + self.add_state("TN", default=[], dist_reduce_fx="cat") + self.add_state("FN", default=[], dist_reduce_fx="cat") + + def update( + self, + preds: Tensor, + targets: Tensor, + mask: Tensor | None = None, + weights: Tensor | None = None, + *args, + ): + mask = torch.ones_like(targets, dtype=torch.bool) if mask is None else mask + weights = torch.ones_like(targets, dtype=torch.float) if weights is None else weights + + if not (0 <= preds.min() and preds.max() <= 1): # assume logits + preds = preds.sigmoid() + + TP, FP, TN, FN = self._calc_unreduced_loss(preds, targets.long(), mask, weights, *args) + + self.TP += [TP] + self.FP += [FP] + self.TN += [TN] + self.FN += [FN] + + def _calc_unreduced_loss(self, preds, targets, mask, weights, *args) -> Tensor: + TP = (targets * preds * weights * mask).sum(0, keepdim=True) + FP = ((1 - targets) * preds * weights * mask).sum(0, keepdim=True) + TN = ((1 - targets) * (1 - preds) * weights * mask).sum(0, keepdim=True) + FN = (targets * (1 - preds) * weights * mask).sum(0, keepdim=True) + + return TP, FP, TN, FN + + def compute(self): + TP = dim_zero_cat(self.TP).sum(0) + FP = dim_zero_cat(self.FP).sum(0) + TN = dim_zero_cat(self.TN).sum(0) + FN = dim_zero_cat(self.FN).sum(0) + + MCC = (TP * TN - FP * FN) / ((TP + FP) * (TP + FN) * (TN + FP) * (TN + FN) + 1e-8).sqrt() + MCC = MCC * self.task_weights + return 1 - MCC.mean() + + +@MetricRegistry.register("binary-mcc") +class BinaryMCCMetric(BinaryMCCLoss): + def compute(self): + return 1 - super().compute() + + +@LossFunctionRegistry.register("multiclass-mcc") +class MulticlassMCCLoss(ChempropMetric): + """Calculate a soft Matthews correlation coefficient ([mccWiki]_) loss for multiclass + classification based on the implementataion of [mccSklearn]_ + References + ---------- + .. [mccWiki] https://en.wikipedia.org/wiki/Phi_coefficient#Multiclass_case + .. [mccSklearn] https://scikit-learn.org/stable/modules/generated/sklearn.metrics.matthews_corrcoef.html + """ + + def __init__(self, task_weights: ArrayLike = 1.0): + """ + Parameters + ---------- + task_weights : ArrayLike, default=1.0 + the per-task weights of shape `t` or `1 x t`. Defaults to all tasks having a weight of 1. 
+ """ + super().__init__(task_weights) + + self.add_state("p", default=[], dist_reduce_fx="cat") + self.add_state("t", default=[], dist_reduce_fx="cat") + self.add_state("c", default=[], dist_reduce_fx="cat") + self.add_state("s", default=[], dist_reduce_fx="cat") + + def update( + self, + preds: Tensor, + targets: Tensor, + mask: Tensor | None = None, + weights: Tensor | None = None, + *args, + ): + mask = torch.ones_like(targets, dtype=torch.bool) if mask is None else mask + weights = ( + torch.ones_like(targets, dtype=torch.float) if weights is None else weights.view(-1, 1) + ) + + if not (0 <= preds.min() and preds.max() <= 1): # assume logits + preds = preds.softmax(2) + + p, t, c, s = self._calc_unreduced_loss(preds, targets.long(), mask, weights, *args) + + self.p += [p] + self.t += [t] + self.c += [c] + self.s += [s] + + def _calc_unreduced_loss(self, preds, targets, mask, weights, *args) -> Tensor: + device = preds.device + C = preds.shape[2] + bin_targets = torch.eye(C, device=device)[targets] + bin_preds = torch.eye(C, device=device)[preds.argmax(-1)] + masked_data_weights = weights.unsqueeze(2) * mask.unsqueeze(2) + p = (bin_preds * masked_data_weights).sum(0, keepdims=True) + t = (bin_targets * masked_data_weights).sum(0, keepdims=True) + c = (bin_preds * bin_targets * masked_data_weights).sum(2).sum(0, keepdims=True) + s = (preds * masked_data_weights).sum(2).sum(0, keepdims=True) + + return p, t, c, s + + def compute(self): + p = dim_zero_cat(self.p).sum(0) + t = dim_zero_cat(self.t).sum(0) + c = dim_zero_cat(self.c).sum(0) + s = dim_zero_cat(self.s).sum(0) + s2 = s.square() + + # the `einsum` calls amount to calculating the batched dot product + cov_ytyp = c * s - torch.einsum("ij,ij->i", p, t) + cov_ypyp = s2 - torch.einsum("ij,ij->i", p, p) + cov_ytyt = s2 - torch.einsum("ij,ij->i", t, t) + + x = cov_ypyp * cov_ytyt + MCC = torch.where(x == 0, torch.tensor(0.0), cov_ytyp / x.sqrt()) + MCC = MCC * self.task_weights + + return 1 - MCC.mean() + + +@MetricRegistry.register("multiclass-mcc") +class MulticlassMCCMetric(MulticlassMCCLoss): + def compute(self): + return 1 - super().compute() + + +class ClassificationMixin: + def __init__(self, task_weights: ArrayLike = 1.0, **kwargs): + """ + Parameters + ---------- + task_weights : ArrayLike = 1.0 + .. important:: + Ignored. Maintained for compatibility with :class:`ChempropMetric` + """ + super().__init__() + task_weights = torch.as_tensor(task_weights, dtype=torch.float).view(1, -1) + self.register_buffer("task_weights", task_weights) + + def update(self, preds: Tensor, targets: Tensor, mask: Tensor, *args, **kwargs): + super().update(preds[mask], targets[mask].long()) + + +@MetricRegistry.register("roc") +class BinaryAUROC(ClassificationMixin, torchmetrics.classification.BinaryAUROC): + pass + + +@MetricRegistry.register("prc") +class BinaryAUPRC(ClassificationMixin, torchmetrics.classification.BinaryPrecisionRecallCurve): + def compute(self) -> Tensor: + p, r, _ = super().compute() + return auc(r, p) + + +@MetricRegistry.register("accuracy") +class BinaryAccuracy(ClassificationMixin, torchmetrics.classification.BinaryAccuracy): + pass + + +@MetricRegistry.register("f1") +class BinaryF1Score(ClassificationMixin, torchmetrics.classification.BinaryF1Score): + pass + + +@LossFunctionRegistry.register("dirichlet") +class DirichletLoss(ChempropMetric): + """Uses the loss function from [sensoy2018]_ based on the implementation at [sensoyGithub]_ + + References + ---------- + .. [sensoy2018] Sensoy, M.; Kaplan, L.; Kandemir, M. 
"Evidential deep learning to quantify + classification uncertainty." NeurIPS, 2018, 31. https://doi.org/10.48550/arXiv.1806.01768 + .. [sensoyGithub] https://muratsensoy.github.io/uncertainty.html#Define-the-loss-function + """ + + def __init__(self, task_weights: ArrayLike = 1.0, v_kl: float = 0.2): + super().__init__(task_weights) + self.v_kl = v_kl + + def _calc_unreduced_loss(self, preds: Tensor, targets: Tensor, *args) -> Tensor: + targets = torch.eye(preds.shape[2], device=preds.device)[targets.long()] + + S = preds.sum(-1, keepdim=True) + p = preds / S + + A = (targets - p).square().sum(-1, keepdim=True) + B = ((p * (1 - p)) / (S + 1)).sum(-1, keepdim=True) + + L_mse = A + B + + alpha = targets + (1 - targets) * preds + beta = torch.ones_like(alpha) + S_alpha = alpha.sum(-1, keepdim=True) + S_beta = beta.sum(-1, keepdim=True) + + ln_alpha = S_alpha.lgamma() - alpha.lgamma().sum(-1, keepdim=True) + ln_beta = beta.lgamma().sum(-1, keepdim=True) - S_beta.lgamma() + + dg0 = torch.digamma(alpha) + dg1 = torch.digamma(S_alpha) + + L_kl = ln_alpha + ln_beta + torch.sum((alpha - beta) * (dg0 - dg1), -1, keepdim=True) + + return (L_mse + self.v_kl * L_kl).mean(-1) + + def extra_repr(self) -> str: + return f"v_kl={self.v_kl}" + + +@LossFunctionRegistry.register("sid") +class SID(ChempropMetric): + def __init__(self, task_weights: ArrayLike = 1.0, threshold: float | None = None, **kwargs): + super().__init__(task_weights, **kwargs) + + self.threshold = threshold + + def _calc_unreduced_loss(self, preds: Tensor, targets: Tensor, mask: Tensor, *args) -> Tensor: + if self.threshold is not None: + preds = preds.clamp(min=self.threshold) + + preds_norm = preds / (preds * mask).sum(1, keepdim=True) + + targets = targets.masked_fill(~mask, 1) + preds_norm = preds_norm.masked_fill(~mask, 1) + + return (preds_norm / targets).log() * preds_norm + (targets / preds_norm).log() * targets + + def extra_repr(self) -> str: + return f"threshold={self.threshold}" + + +@LossFunctionRegistry.register(["earthmovers", "wasserstein"]) +class Wasserstein(ChempropMetric): + def __init__(self, task_weights: ArrayLike = 1.0, threshold: float | None = None): + super().__init__(task_weights) + + self.threshold = threshold + + def _calc_unreduced_loss(self, preds: Tensor, targets: Tensor, mask: Tensor, *args) -> Tensor: + if self.threshold is not None: + preds = preds.clamp(min=self.threshold) + + preds_norm = preds / (preds * mask).sum(1, keepdim=True) + + return (targets.cumsum(1) - preds_norm.cumsum(1)).abs() + + def extra_repr(self) -> str: + return f"threshold={self.threshold}" + + +@LossFunctionRegistry.register(["quantile", "pinball"]) +class QuantileLoss(ChempropMetric): + def __init__(self, task_weights: ArrayLike = 1.0, alpha: float = 0.1): + super().__init__(task_weights) + self.alpha = alpha + + bounds = torch.tensor([-1 / 2, 1 / 2]).view(-1, 1, 1) + tau = torch.tensor([[alpha / 2, 1 - alpha / 2], [alpha / 2 - 1, -alpha / 2]]).view( + 2, 2, 1, 1 + ) + + self.register_buffer("bounds", bounds) + self.register_buffer("tau", tau) + + def _calc_unreduced_loss(self, preds: Tensor, targets: Tensor, mask: Tensor, *args) -> Tensor: + mean, interval = torch.unbind(preds, dim=-1) + + interval_bounds = self.bounds * interval + pred_bounds = mean + interval_bounds + error_bounds = targets - pred_bounds + loss_bounds = (self.tau * error_bounds).amax(0) + + return loss_bounds.sum(0) + + def extra_repr(self) -> str: + return f"alpha={self.alpha}" diff --git a/chemprop-updated/chemprop/nn/predictors.py 
b/chemprop-updated/chemprop/nn/predictors.py new file mode 100644 index 0000000000000000000000000000000000000000..45d6ed415a7f8b599d1d213f768590c8ee3a8112 --- /dev/null +++ b/chemprop-updated/chemprop/nn/predictors.py @@ -0,0 +1,369 @@ +from abc import abstractmethod + +from lightning.pytorch.core.mixins import HyperparametersMixin +import torch +from torch import Tensor, nn +from torch.nn import functional as F + +from chemprop.conf import DEFAULT_HIDDEN_DIM +from chemprop.nn.ffn import MLP +from chemprop.nn.hparams import HasHParams +from chemprop.nn.metrics import ( + MSE, + SID, + BCELoss, + BinaryAUROC, + ChempropMetric, + CrossEntropyLoss, + DirichletLoss, + EvidentialLoss, + MulticlassMCCMetric, + MVELoss, + QuantileLoss, +) +from chemprop.nn.transforms import UnscaleTransform +from chemprop.utils import ClassRegistry, Factory + +__all__ = [ + "Predictor", + "PredictorRegistry", + "RegressionFFN", + "MveFFN", + "EvidentialFFN", + "BinaryClassificationFFNBase", + "BinaryClassificationFFN", + "BinaryDirichletFFN", + "MulticlassClassificationFFN", + "MulticlassDirichletFFN", + "SpectralFFN", +] + + +class Predictor(nn.Module, HasHParams): + r"""A :class:`Predictor` is a protocol that defines a differentiable function + :math:`f` : \mathbb R^d \mapsto \mathbb R^o""" + + input_dim: int + """the input dimension""" + output_dim: int + """the output dimension""" + n_tasks: int + """the number of tasks `t` to predict for each input""" + n_targets: int + """the number of targets `s` to predict for each task `t`""" + criterion: ChempropMetric + """the loss function to use for training""" + task_weights: Tensor + """the weights to apply to each task when calculating the loss""" + output_transform: UnscaleTransform + """the transform to apply to the output of the predictor""" + + @abstractmethod + def forward(self, Z: Tensor) -> Tensor: + pass + + @abstractmethod + def train_step(self, Z: Tensor) -> Tensor: + pass + + @abstractmethod + def encode(self, Z: Tensor, i: int) -> Tensor: + """Calculate the :attr:`i`-th hidden representation + + Parameters + ---------- + Z : Tensor + a tensor of shape ``n x d`` containing the input data to encode, where ``d`` is the + input dimensionality. + i : int + The stop index of slice of the MLP used to encode the input. That is, use all + layers in the MLP *up to* :attr:`i` (i.e., ``MLP[:i]``). This can be any integer + value, and the behavior of this function is dependent on the underlying list + slicing behavior. For example: + + * ``i=0``: use a 0-layer MLP (i.e., a no-op) + * ``i=1``: use only the first block + * ``i=-1``: use *up to* the final block + + Returns + ------- + Tensor + a tensor of shape ``n x h`` containing the :attr:`i`-th hidden representation, where + ``h`` is the number of neurons in the :attr:`i`-th hidden layer. + """ + pass + + +PredictorRegistry = ClassRegistry[Predictor]() + + +class _FFNPredictorBase(Predictor, HyperparametersMixin): + """A :class:`_FFNPredictorBase` is the base class for all :class:`Predictor`\s that use an + underlying :class:`SimpleFFN` to map the learned fingerprint to the desired output. 
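# Illustrative sketch (not part of this diff): the MLP that `MLP.build` lays out for the
# FFN predictors below. The dimensions are assumptions.
from chemprop.nn.ffn import MLP

ffn = MLP.build(input_dim=300, output_dim=1, hidden_dim=300, n_layers=2, dropout=0.1)
# blocks: Linear(300, 300), then (act, dropout, Linear(300, 300)), (act, dropout, Linear(300, 1))
assert ffn.input_dim == 300 and ffn.output_dim == 1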
+ """ + + _T_default_criterion: ChempropMetric + _T_default_metric: ChempropMetric + + def __init__( + self, + n_tasks: int = 1, + input_dim: int = DEFAULT_HIDDEN_DIM, + hidden_dim: int = 300, + n_layers: int = 1, + dropout: float = 0.0, + activation: str = "relu", + criterion: ChempropMetric | None = None, + task_weights: Tensor | None = None, + threshold: float | None = None, + output_transform: UnscaleTransform | None = None, + ): + super().__init__() + # manually add criterion and output_transform to hparams to suppress lightning's warning + # about double saving their state_dict values. + self.save_hyperparameters(ignore=["criterion", "output_transform"]) + self.hparams["criterion"] = criterion + self.hparams["output_transform"] = output_transform + self.hparams["cls"] = self.__class__ + + self.ffn = MLP.build( + input_dim, n_tasks * self.n_targets, hidden_dim, n_layers, dropout, activation + ) + task_weights = torch.ones(n_tasks) if task_weights is None else task_weights + self.criterion = criterion or Factory.build( + self._T_default_criterion, task_weights=task_weights, threshold=threshold + ) + self.output_transform = output_transform if output_transform is not None else nn.Identity() + + @property + def input_dim(self) -> int: + return self.ffn.input_dim + + @property + def output_dim(self) -> int: + return self.ffn.output_dim + + @property + def n_tasks(self) -> int: + return self.output_dim // self.n_targets + + def forward(self, Z: Tensor) -> Tensor: + return self.ffn(Z) + + def encode(self, Z: Tensor, i: int) -> Tensor: + return self.ffn[:i](Z) + + +@PredictorRegistry.register("regression") +class RegressionFFN(_FFNPredictorBase): + n_targets = 1 + _T_default_criterion = MSE + _T_default_metric = MSE + + def forward(self, Z: Tensor) -> Tensor: + return self.output_transform(self.ffn(Z)) + + train_step = forward + + +@PredictorRegistry.register("regression-mve") +class MveFFN(RegressionFFN): + n_targets = 2 + _T_default_criterion = MVELoss + + def forward(self, Z: Tensor) -> Tensor: + Y = self.ffn(Z) + mean, var = torch.chunk(Y, self.n_targets, 1) + var = F.softplus(var) + + mean = self.output_transform(mean) + if not isinstance(self.output_transform, nn.Identity): + var = self.output_transform.transform_variance(var) + + return torch.stack((mean, var), dim=2) + + train_step = forward + + +@PredictorRegistry.register("regression-evidential") +class EvidentialFFN(RegressionFFN): + n_targets = 4 + _T_default_criterion = EvidentialLoss + + def forward(self, Z: Tensor) -> Tensor: + Y = self.ffn(Z) + mean, v, alpha, beta = torch.chunk(Y, self.n_targets, 1) + v = F.softplus(v) + alpha = F.softplus(alpha) + 1 + beta = F.softplus(beta) + + mean = self.output_transform(mean) + if not isinstance(self.output_transform, nn.Identity): + beta = self.output_transform.transform_variance(beta) + + return torch.stack((mean, v, alpha, beta), dim=2) + + train_step = forward + + +@PredictorRegistry.register("regression-quantile") +class QuantileFFN(RegressionFFN): + n_targets = 2 + _T_default_criterion = QuantileLoss + + def forward(self, Z: Tensor) -> Tensor: + Y = super().forward(Z) + lower_bound, upper_bound = torch.chunk(Y, self.n_targets, 1) + + lower_bound = self.output_transform(lower_bound) + upper_bound = self.output_transform(upper_bound) + + mean = (lower_bound + upper_bound) / 2 + interval = upper_bound - lower_bound + + return torch.stack((mean, interval), dim=2) + + train_step = forward + + +class BinaryClassificationFFNBase(_FFNPredictorBase): + pass + + 
+@PredictorRegistry.register("classification") +class BinaryClassificationFFN(BinaryClassificationFFNBase): + n_targets = 1 + _T_default_criterion = BCELoss + _T_default_metric = BinaryAUROC + + def forward(self, Z: Tensor) -> Tensor: + Y = super().forward(Z) + + return Y.sigmoid() + + def train_step(self, Z: Tensor) -> Tensor: + return super().forward(Z) + + +@PredictorRegistry.register("classification-dirichlet") +class BinaryDirichletFFN(BinaryClassificationFFNBase): + n_targets = 2 + _T_default_criterion = DirichletLoss + _T_default_metric = BinaryAUROC + + def forward(self, Z: Tensor) -> Tensor: + Y = super().forward(Z).reshape(len(Z), -1, 2) + + alpha = F.softplus(Y) + 1 + + u = 2 / alpha.sum(-1) + Y = alpha / alpha.sum(-1, keepdim=True) + + return torch.stack((Y[..., 1], u), dim=2) + + def train_step(self, Z: Tensor) -> Tensor: + Y = super().forward(Z).reshape(len(Z), -1, 2) + + return F.softplus(Y) + 1 + + +@PredictorRegistry.register("multiclass") +class MulticlassClassificationFFN(_FFNPredictorBase): + n_targets = 1 + _T_default_criterion = CrossEntropyLoss + _T_default_metric = MulticlassMCCMetric + + def __init__( + self, + n_classes: int, + n_tasks: int = 1, + input_dim: int = DEFAULT_HIDDEN_DIM, + hidden_dim: int = 300, + n_layers: int = 1, + dropout: float = 0.0, + activation: str = "relu", + criterion: ChempropMetric | None = None, + task_weights: Tensor | None = None, + threshold: float | None = None, + output_transform: UnscaleTransform | None = None, + ): + task_weights = torch.ones(n_tasks) if task_weights is None else task_weights + super().__init__( + n_tasks * n_classes, + input_dim, + hidden_dim, + n_layers, + dropout, + activation, + criterion, + task_weights, + threshold, + output_transform, + ) + + self.n_classes = n_classes + + @property + def n_tasks(self) -> int: + return self.output_dim // (self.n_targets * self.n_classes) + + def forward(self, Z: Tensor) -> Tensor: + return self.train_step(Z).softmax(-1) + + def train_step(self, Z: Tensor) -> Tensor: + return super().forward(Z).reshape(Z.shape[0], -1, self.n_classes) + + +@PredictorRegistry.register("multiclass-dirichlet") +class MulticlassDirichletFFN(MulticlassClassificationFFN): + _T_default_criterion = DirichletLoss + _T_default_metric = MulticlassMCCMetric + + def forward(self, Z: Tensor) -> Tensor: + Y = super().train_step(Z) + + alpha = F.softplus(Y) + 1 + + Y = alpha / alpha.sum(-1, keepdim=True) + + return Y + + def train_step(self, Z: Tensor) -> Tensor: + Y = super().train_step(Z) + + return F.softplus(Y) + 1 + + +class _Exp(nn.Module): + def forward(self, X: Tensor): + return X.exp() + + +@PredictorRegistry.register("spectral") +class SpectralFFN(_FFNPredictorBase): + n_targets = 1 + _T_default_criterion = SID + _T_default_metric = SID + + def __init__(self, *args, spectral_activation: str | None = "softplus", **kwargs): + super().__init__(*args, **kwargs) + + match spectral_activation: + case "exp": + spectral_activation = _Exp() + case "softplus" | None: + spectral_activation = nn.Softplus() + case _: + raise ValueError( + f"Unknown spectral activation: {spectral_activation}. " + "Expected one of 'exp', 'softplus' or None." 
+ ) + + self.ffn.add_module("spectral_activation", spectral_activation) + + def forward(self, Z: Tensor) -> Tensor: + Y = super().forward(Z) + Y = self.ffn.spectral_activation(Y) + return Y / Y.sum(1, keepdim=True) + + train_step = forward diff --git a/chemprop-updated/chemprop/nn/transforms.py b/chemprop-updated/chemprop/nn/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..2af42099aab6409b138316342babd9c209b1b060 --- /dev/null +++ b/chemprop-updated/chemprop/nn/transforms.py @@ -0,0 +1,70 @@ +from numpy.typing import ArrayLike +from sklearn.preprocessing import StandardScaler +import torch +from torch import Tensor, nn + +from chemprop.data.collate import BatchMolGraph + + +class _ScaleTransformMixin(nn.Module): + def __init__(self, mean: ArrayLike, scale: ArrayLike, pad: int = 0): + super().__init__() + + mean = torch.cat([torch.zeros(pad), torch.tensor(mean, dtype=torch.float)]) + scale = torch.cat([torch.ones(pad), torch.tensor(scale, dtype=torch.float)]) + + if mean.shape != scale.shape: + raise ValueError( + f"uneven shapes for 'mean' and 'scale'! got: mean={mean.shape}, scale={scale.shape}" + ) + + self.register_buffer("mean", mean.unsqueeze(0)) + self.register_buffer("scale", scale.unsqueeze(0)) + + @classmethod + def from_standard_scaler(cls, scaler: StandardScaler, pad: int = 0): + return cls(scaler.mean_, scaler.scale_, pad=pad) + + def to_standard_scaler(self, anti_pad: int = 0) -> StandardScaler: + scaler = StandardScaler() + scaler.mean_ = self.mean[anti_pad:].numpy() + scaler.scale_ = self.scale[anti_pad:].numpy() + return scaler + + +class ScaleTransform(_ScaleTransformMixin): + def forward(self, X: Tensor) -> Tensor: + if self.training: + return X + + return (X - self.mean) / self.scale + + +class UnscaleTransform(_ScaleTransformMixin): + def forward(self, X: Tensor) -> Tensor: + if self.training: + return X + + return X * self.scale + self.mean + + def transform_variance(self, var: Tensor) -> Tensor: + if self.training: + return var + return var * (self.scale**2) + + +class GraphTransform(nn.Module): + def __init__(self, V_transform: ScaleTransform, E_transform: ScaleTransform): + super().__init__() + + self.V_transform = V_transform + self.E_transform = E_transform + + def forward(self, bmg: BatchMolGraph) -> BatchMolGraph: + if self.training: + return bmg + + bmg.V = self.V_transform(bmg.V) + bmg.E = self.E_transform(bmg.E) + + return bmg diff --git a/chemprop-updated/chemprop/nn/utils.py b/chemprop-updated/chemprop/nn/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..19913bd0164f385944deb01a70c58fbdb7cd8587 --- /dev/null +++ b/chemprop-updated/chemprop/nn/utils.py @@ -0,0 +1,46 @@ +from enum import auto + +from torch import nn + +from chemprop.utils.utils import EnumMapping + + +class Activation(EnumMapping): + RELU = auto() + LEAKYRELU = auto() + PRELU = auto() + TANH = auto() + SELU = auto() + ELU = auto() + + +def get_activation_function(activation: str | Activation) -> nn.Module: + """Gets an activation function module given the name of the activation. + + See :class:`~chemprop.v2.models.utils.Activation` for available activations. + + Parameters + ---------- + activation : str | Activation + The name of the activation function. + + Returns + ------- + nn.Module + The activation function module. 
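# Illustrative sketch (not part of this diff): both strings and Activation members resolve
# to the same torch modules via the match statement below.
from torch import nn
from chemprop.nn.utils import Activation, get_activation_function

assert isinstance(get_activation_function("relu"), nn.ReLU)
assert isinstance(get_activation_function(Activation.LEAKYRELU), nn.LeakyReLU)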
+ """ + match Activation.get(activation): + case Activation.RELU: + return nn.ReLU() + case Activation.LEAKYRELU: + return nn.LeakyReLU(0.1) + case Activation.PRELU: + return nn.PReLU() + case Activation.TANH: + return nn.Tanh() + case Activation.SELU: + return nn.SELU() + case Activation.ELU: + return nn.ELU() + case _: + raise RuntimeError("unreachable code reached!") diff --git a/chemprop-updated/chemprop/schedulers.py b/chemprop-updated/chemprop/schedulers.py new file mode 100644 index 0000000000000000000000000000000000000000..843df0f8a75585ea6a309cc1792f93ff15096218 --- /dev/null +++ b/chemprop-updated/chemprop/schedulers.py @@ -0,0 +1,65 @@ +from torch.optim import Optimizer +from torch.optim.lr_scheduler import LambdaLR + + +def build_NoamLike_LRSched( + optimizer: Optimizer, + warmup_steps: int, + cooldown_steps: int, + init_lr: float, + max_lr: float, + final_lr: float, +): + r"""Build a Noam-like learning rate scheduler which schedules the learning rate with a piecewise linear followed + by an exponential decay. + + The learning rate increases linearly from ``init_lr`` to ``max_lr`` over the course of + the first warmup_steps then decreases exponentially to ``final_lr`` over the course of the + remaining ``total_steps - warmup_steps`` (where ``total_steps = total_epochs * steps_per_epoch``). This is roughly based on the learning rate schedule from [1]_, section 5.3. + + Formally, the learning rate schedule is defined as: + + .. math:: + \mathtt{lr}(i) &= + \begin{cases} + \mathtt{init\_lr} + \delta \cdot i &\text{if } i < \mathtt{warmup\_steps} \\ + \mathtt{max\_lr} \cdot \left( \frac{\mathtt{final\_lr}}{\mathtt{max\_lr}} \right)^{\gamma(i)} &\text{otherwise} \\ + \end{cases} + \\ + \delta &\mathrel{:=} + \frac{\mathtt{max\_lr} - \mathtt{init\_lr}}{\mathtt{warmup\_steps}} \\ + \gamma(i) &\mathrel{:=} + \frac{i - \mathtt{warmup\_steps}}{\mathtt{total\_steps} - \mathtt{warmup\_steps}} + + + Parameters + ----------- + optimizer : Optimizer + A PyTorch optimizer. + warmup_steps : int + The number of steps during which to linearly increase the learning rate. + cooldown_steps : int + The number of steps during which to exponential decay the learning rate. + init_lr : float + The initial learning rate. + max_lr : float + The maximum learning rate (achieved after ``warmup_steps``). + final_lr : float + The final learning rate (achieved after ``cooldown_steps``). + + References + ---------- + .. [1] Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, Ł. and Polosukhin, I. "Attention is all you need." Advances in neural information processing systems, 2017, 30. 
https://arxiv.org/abs/1706.03762 + """ + + def lr_lambda(step: int): + if step < warmup_steps: + warmup_factor = (max_lr - init_lr) / warmup_steps + return step * warmup_factor / init_lr + 1 + elif warmup_steps <= step < warmup_steps + cooldown_steps: + cooldown_factor = (final_lr / max_lr) ** (1 / cooldown_steps) + return (max_lr * (cooldown_factor ** (step - warmup_steps))) / init_lr + else: + return final_lr / init_lr + + return LambdaLR(optimizer, lr_lambda) diff --git a/chemprop-updated/chemprop/train/__pycache__/__init__.cpython-37.pyc b/chemprop-updated/chemprop/train/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44d67a6fc3904d2cdb0108af55f86ed8c7d27c11 Binary files /dev/null and b/chemprop-updated/chemprop/train/__pycache__/__init__.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/train/__pycache__/cross_validate.cpython-37.pyc b/chemprop-updated/chemprop/train/__pycache__/cross_validate.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e23c8c9cea5c9ab57acbeb920dd00b62b851e1b Binary files /dev/null and b/chemprop-updated/chemprop/train/__pycache__/cross_validate.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/train/__pycache__/evaluate.cpython-37.pyc b/chemprop-updated/chemprop/train/__pycache__/evaluate.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44c1858923c6b770124edc0bb26cf82fb4bcc080 Binary files /dev/null and b/chemprop-updated/chemprop/train/__pycache__/evaluate.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/train/__pycache__/loss_functions.cpython-37.pyc b/chemprop-updated/chemprop/train/__pycache__/loss_functions.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fab53738a6550cb4998803b7b83e59bfa8f0dce Binary files /dev/null and b/chemprop-updated/chemprop/train/__pycache__/loss_functions.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/train/__pycache__/make_predictions.cpython-37.pyc b/chemprop-updated/chemprop/train/__pycache__/make_predictions.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92cfc083cdacf41747c5695c45200be25e7778c3 Binary files /dev/null and b/chemprop-updated/chemprop/train/__pycache__/make_predictions.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/train/__pycache__/metrics.cpython-37.pyc b/chemprop-updated/chemprop/train/__pycache__/metrics.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c04dcf0fa814e32cd868bc918606271acc6eb81 Binary files /dev/null and b/chemprop-updated/chemprop/train/__pycache__/metrics.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/train/__pycache__/molecule_fingerprint.cpython-37.pyc b/chemprop-updated/chemprop/train/__pycache__/molecule_fingerprint.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..881b8725854d2168acc4645b07ff5c2dbc4bd27c Binary files /dev/null and b/chemprop-updated/chemprop/train/__pycache__/molecule_fingerprint.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/train/__pycache__/predict.cpython-37.pyc b/chemprop-updated/chemprop/train/__pycache__/predict.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1749473999e3adbeecc35ca7bc9169595d7ab61e Binary files /dev/null and b/chemprop-updated/chemprop/train/__pycache__/predict.cpython-37.pyc differ diff --git 
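# Illustrative sketch (not part of this diff): wiring the Noam-like scheduler defined
# above to an optimizer. Step counts and learning rates are made-up values; the
# optimizer's lr should equal init_lr, since lr_lambda rescales relative to it.
import torch
from chemprop.schedulers import build_NoamLike_LRSched

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.Adam(params, lr=1e-4)
sched = build_NoamLike_LRSched(
    opt, warmup_steps=100, cooldown_steps=900, init_lr=1e-4, max_lr=1e-3, final_lr=1e-4
)
for _ in range(10):
    opt.step()
    sched.step()  # lr ramps linearly toward max_lr during warmup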
a/chemprop-updated/chemprop/train/__pycache__/run_training.cpython-37.pyc b/chemprop-updated/chemprop/train/__pycache__/run_training.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db47bbf8b1fc85943c5e820bf67488da35d73665 Binary files /dev/null and b/chemprop-updated/chemprop/train/__pycache__/run_training.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/train/__pycache__/train.cpython-37.pyc b/chemprop-updated/chemprop/train/__pycache__/train.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa495ba3f298f0c511b601efc272ee055a9a128b Binary files /dev/null and b/chemprop-updated/chemprop/train/__pycache__/train.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/types.py b/chemprop-updated/chemprop/types.py new file mode 100644 index 0000000000000000000000000000000000000000..71ef27b18cfde9504644f7e627668d5ce62aa431 --- /dev/null +++ b/chemprop-updated/chemprop/types.py @@ -0,0 +1,3 @@ +from rdkit.Chem import Mol + +Rxn = tuple[Mol, Mol] diff --git a/chemprop-updated/chemprop/uncertainty/__init__.py b/chemprop-updated/chemprop/uncertainty/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d81fe53b0575837d843eaca6ae52e4462149c526 --- /dev/null +++ b/chemprop-updated/chemprop/uncertainty/__init__.py @@ -0,0 +1,94 @@ +from .calibrator import ( + AdaptiveMulticlassConformalCalibrator, + BinaryClassificationCalibrator, + CalibratorBase, + IsotonicCalibrator, + IsotonicMulticlassCalibrator, + MulticlassClassificationCalibrator, + MulticlassConformalCalibrator, + MultilabelConformalCalibrator, + MVEWeightingCalibrator, + PlattCalibrator, + RegressionCalibrator, + RegressionConformalCalibrator, + UncertaintyCalibratorRegistry, + ZelikmanCalibrator, + ZScalingCalibrator, +) +from .estimator import ( # RoundRobinSpectraEstimator, + ClassEstimator, + ClassificationDirichletEstimator, + DropoutEstimator, + EnsembleEstimator, + EvidentialAleatoricEstimator, + EvidentialEpistemicEstimator, + EvidentialTotalEstimator, + MulticlassDirichletEstimator, + MVEEstimator, + NoUncertaintyEstimator, + QuantileRegressionEstimator, + UncertaintyEstimator, + UncertaintyEstimatorRegistry, +) +from .evaluator import ( + BinaryClassificationEvaluator, + CalibrationAreaEvaluator, + ExpectedNormalizedErrorEvaluator, + MulticlassClassificationEvaluator, + MulticlassConformalEvaluator, + MultilabelConformalEvaluator, + NLLClassEvaluator, + NLLMulticlassEvaluator, + NLLRegressionEvaluator, + RegressionConformalEvaluator, + RegressionEvaluator, + SpearmanEvaluator, + UncertaintyEvaluatorRegistry, +) + +__all__ = [ + "AdaptiveMulticlassConformalCalibrator", + "BinaryClassificationCalibrator", + "CalibratorBase", + "IsotonicCalibrator", + "IsotonicMulticlassCalibrator", + "MulticlassClassificationCalibrator", + "MulticlassConformalCalibrator", + "MultilabelConformalCalibrator", + "MVEWeightingCalibrator", + "PlattCalibrator", + "RegressionCalibrator", + "RegressionConformalCalibrator", + "UncertaintyCalibratorRegistry", + "ZelikmanCalibrator", + "ZScalingCalibrator", + "BinaryClassificationEvaluator", + "CalibrationAreaEvaluator", + "ExpectedNormalizedErrorEvaluator", + "MulticlassClassificationEvaluator", + "MetricEvaluator", + "MulticlassConformalEvaluator", + "MultilabelConformalEvaluator", + "NLLClassEvaluator", + "NLLMulticlassEvaluator", + "NLLRegressionEvaluator", + "RegressionConformalEvaluator", + "RegressionEvaluator", + "SpearmanEvaluator", + "UncertaintyEvaluator", + "UncertaintyEvaluatorRegistry", 
+ "ClassificationDirichletEstimator", + "ClassEstimator", + "MulticlassDirichletEstimator", + "DropoutEstimator", + "EnsembleEstimator", + "EvidentialAleatoricEstimator", + "EvidentialEpistemicEstimator", + "EvidentialTotalEstimator", + "MVEEstimator", + "NoUncertaintyEstimator", + "QuantileRegressionEstimator", + # "RoundRobinSpectraEstimator", + "UncertaintyEstimator", + "UncertaintyEstimatorRegistry", +] diff --git a/chemprop-updated/chemprop/uncertainty/__pycache__/__init__.cpython-37.pyc b/chemprop-updated/chemprop/uncertainty/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43c37cc34ec737d54c1d2e08e665efd9c3520b24 Binary files /dev/null and b/chemprop-updated/chemprop/uncertainty/__pycache__/__init__.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/uncertainty/__pycache__/uncertainty_calibrator.cpython-37.pyc b/chemprop-updated/chemprop/uncertainty/__pycache__/uncertainty_calibrator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d522582d59e6606267b4e8ba0eff021c66ffb258 Binary files /dev/null and b/chemprop-updated/chemprop/uncertainty/__pycache__/uncertainty_calibrator.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/uncertainty/__pycache__/uncertainty_estimator.cpython-37.pyc b/chemprop-updated/chemprop/uncertainty/__pycache__/uncertainty_estimator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..046c33150f6a21736d768d14293ef97fd8c23a17 Binary files /dev/null and b/chemprop-updated/chemprop/uncertainty/__pycache__/uncertainty_estimator.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/uncertainty/__pycache__/uncertainty_evaluator.cpython-37.pyc b/chemprop-updated/chemprop/uncertainty/__pycache__/uncertainty_evaluator.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc1951758f631ce5e1331991ec1ade98b9244cff Binary files /dev/null and b/chemprop-updated/chemprop/uncertainty/__pycache__/uncertainty_evaluator.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/uncertainty/__pycache__/uncertainty_predictor.cpython-37.pyc b/chemprop-updated/chemprop/uncertainty/__pycache__/uncertainty_predictor.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..131fadea910f50ac7b82fa340cf6cbcf39d0c981 Binary files /dev/null and b/chemprop-updated/chemprop/uncertainty/__pycache__/uncertainty_predictor.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/uncertainty/calibrator.py b/chemprop-updated/chemprop/uncertainty/calibrator.py new file mode 100644 index 0000000000000000000000000000000000000000..6c9769e7aed4bfaf1e54030932be11fb54304171 --- /dev/null +++ b/chemprop-updated/chemprop/uncertainty/calibrator.py @@ -0,0 +1,715 @@ +from abc import ABC, abstractmethod +import logging +import math +from typing import Self + +import numpy as np +from scipy.optimize import fmin +from scipy.special import expit, logit, softmax +from sklearn.isotonic import IsotonicRegression +import torch +from torch import Tensor + +from chemprop.utils.registry import ClassRegistry + +logger = logging.getLogger(__name__) + + +class CalibratorBase(ABC): + """ + A base class for calibrating the predicted uncertainties. + """ + + @abstractmethod + def fit(self, *args, **kwargs) -> Self: + """ + Fit calibration method for the calibration data. + """ + + @abstractmethod + def apply(self, uncs: Tensor) -> Tensor: + """ + Apply this calibrator to the input uncertainties. 
+ + Parameters + ---------- + uncs: Tensor + a tensor containinig uncalibrated uncertainties + + Returns + ------- + Tensor + the calibrated uncertainties + """ + + +UncertaintyCalibratorRegistry = ClassRegistry[CalibratorBase]() + + +class RegressionCalibrator(CalibratorBase): + """ + A class for calibrating the predicted uncertainties in regressions tasks. + """ + + @abstractmethod + def fit(self, preds: Tensor, uncs: Tensor, targets: Tensor, mask: Tensor) -> Self: + """ + Fit calibration method for the calibration data. + + Parameters + ---------- + preds: Tensor + the predictions for regression tasks. It is a tensor of the shape of ``n x t``, where ``n`` is + the number of input molecules/reactions, and ``t`` is the number of tasks. + uncs: Tensor + the predicted uncertainties of the shape of ``n x t`` + targets: Tensor + a tensor of the shape ``n x t`` + mask: Tensor + a tensor of the shape ``n x t`` indicating whether the given values should be used in the fitting + + Returns + ------- + self : RegressionCalibrator + the fitted calibrator + """ + + +@UncertaintyCalibratorRegistry.register("zscaling") +class ZScalingCalibrator(RegressionCalibrator): + """Calibrate regression datasets by applying a scaling value to the uncalibrated standard deviation, + fitted by minimizing the negative-log-likelihood of a normal distribution around each prediction. [levi2022]_ + + References + ---------- + .. [levi2022] Levi, D.; Gispan, L.; Giladi, N.; Fetaya, E. "Evaluating and Calibrating Uncertainty Prediction in + Regression Tasks." Sensors, 2022, 22(15), 5540. https://www.mdpi.com/1424-8220/22/15/5540 + """ + + def fit(self, preds: Tensor, uncs: Tensor, targets: Tensor, mask: Tensor) -> Self: + scalings = np.zeros(uncs.shape[1]) + for j in range(uncs.shape[1]): + mask_j = mask[:, j] + preds_j = preds[:, j][mask_j].numpy() + uncs_j = uncs[:, j][mask_j].numpy() + targets_j = targets[:, j][mask_j].numpy() + errors = preds_j - targets_j + + def objective(scaler_value: float): + scaled_vars = uncs_j * scaler_value**2 + nll = np.log(2 * np.pi * scaled_vars) / 2 + errors**2 / (2 * scaled_vars) + return nll.sum() + + zscore = errors / np.sqrt(uncs_j) + initial_guess = np.std(zscore) + scalings[j] = fmin(objective, x0=initial_guess, disp=False) + + self.scalings = torch.tensor(scalings) + return self + + def apply(self, uncs: Tensor) -> Tensor: + return uncs * self.scalings**2 + + +@UncertaintyCalibratorRegistry.register("zelikman-interval") +class ZelikmanCalibrator(RegressionCalibrator): + """Calibrate regression datasets using a method that does not depend on a particular probability function form. + + It uses the "CRUDE" method as described in [zelikman2020]_. We implemented this method to be used with variance as the uncertainty. + + Parameters + ---------- + p: float + The target qunatile, :math:`p \in [0, 1]` + + References + ---------- + .. [zelikman2020] Zelikman, E.; Healy, C.; Zhou, S.; Avati, A. "CRUDE: calibrating regression uncertainty distributions + empirically." arXiv preprint arXiv:2005.12496. https://doi.org/10.48550/arXiv.2005.12496 + """ + + def __init__(self, p: float): + super().__init__() + self.p = p + if not 0 <= self.p <= 1: + raise ValueError(f"arg `p` must be between 0 and 1. 
got: {p}.") + + def fit(self, preds: Tensor, uncs: Tensor, targets: Tensor, mask: Tensor) -> Self: + scalings = [] + for j in range(uncs.shape[1]): + mask_j = mask[:, j] + preds_j = preds[:, j][mask_j] + uncs_j = uncs[:, j][mask_j] + targets_j = targets[:, j][mask_j] + z = (preds_j - targets_j).abs() / (uncs_j).sqrt() + scaling = torch.quantile(z, self.p, interpolation="lower") + scalings.append(scaling) + + self.scalings = torch.tensor(scalings) + return self + + def apply(self, uncs: Tensor) -> Tensor: + return uncs * self.scalings**2 + + +@UncertaintyCalibratorRegistry.register("mve-weighting") +class MVEWeightingCalibrator(RegressionCalibrator): + """Calibrate regression datasets that have ensembles of individual models that make variance predictions. + + This method minimizes the negative log likelihood for the predictions versus the targets by applying + a weighted average across the variance predictions of the ensemble. [wang2021]_ + + References + ---------- + .. [wang2021] Wang, D.; Yu, J.; Chen, L.; Li, X.; Jiang, H.; Chen, K.; Zheng, M.; Luo, X. "A hybrid framework + for improving uncertainty quantification in deep learning-based QSAR regression modeling." J. Cheminform., + 2021, 13, 1-17. https://doi.org/10.1186/s13321-021-00551-x + """ + + def fit(self, preds: Tensor, uncs: Tensor, targets: Tensor, mask: Tensor) -> Self: + """ + Fit calibration method for the calibration data. + + Parameters + ---------- + preds: Tensor + the predictions for regression tasks. It is a tensor of the shape of ``n x t``, where ``n`` is + the number of input molecules/reactions, and ``t`` is the number of tasks. + uncs: Tensor + the predicted uncertainties of the shape of ``m x n x t`` + targets: Tensor + a tensor of the shape ``n x t`` + mask: Tensor + a tensor of the shape ``n x t`` indicating whether the given values should be used in the fitting + + Returns + ------- + self : MVEWeightingCalibrator + the fitted calibrator + """ + scalings = [] + for j in range(uncs.shape[2]): + mask_j = mask[:, j] + preds_j = preds[:, j][mask_j].numpy() + uncs_j = uncs[:, mask_j, j].numpy() + targets_j = targets[:, j][mask_j].numpy() + errors = preds_j - targets_j + + def objective(scaler_values: np.ndarray): + scaler_values = np.reshape(softmax(scaler_values), [-1, 1]) # (m, 1) + scaled_vars = np.sum(uncs_j * scaler_values, axis=0, keepdims=False) + nll = np.log(2 * np.pi * scaled_vars) / 2 + errors**2 / (2 * scaled_vars) + return np.sum(nll) + + initial_guess = np.ones(uncs_j.shape[0]) + sol = fmin(objective, x0=initial_guess, disp=False) + scalings.append(torch.tensor(softmax(sol))) + + self.scalings = torch.stack(scalings).t().unsqueeze(1) + return self + + def apply(self, uncs: Tensor) -> Tensor: + """ + Apply this calibrator to the input uncertainties. + + Parameters + ---------- + uncs: Tensor + a tensor containinig uncalibrated uncertainties of the shape of ``m x n x t`` + + Returns + ------- + Tensor + the calibrated uncertainties of the shape of ``n x t`` + """ + return (uncs * self.scalings).sum(0) + + +@UncertaintyCalibratorRegistry.register("conformal-regression") +class RegressionConformalCalibrator(RegressionCalibrator): + r"""Conformalize quantiles to make the interval :math:`[\hat{t}_{\alpha/2}(x),\hat{t}_{1-\alpha/2}(x)]` to have + approximately :math:`1-\alpha` coverage. [angelopoulos2021]_ + + .. 
math:: + s(x, y) &= \max \left\{ \hat{t}_{\alpha/2}(x) - y, y - \hat{t}_{1-\alpha/2}(x) \right\} + + \hat{q} &= Q(s_1, \ldots, s_n; \left\lceil \frac{(n+1)(1-\alpha)}{n} \right\rceil) + + C(x) &= \left[ \hat{t}_{\alpha/2}(x) - \hat{q}, \hat{t}_{1-\alpha/2}(x) + \hat{q} \right] + + where :math:`s` is the nonconformity score as the difference between :math:`y` and its nearest quantile. + :math:`\hat{t}_{\alpha/2}(x)` and :math:`\hat{t}_{1-\alpha/2}(x)` are the predicted quantiles from a quantile + regression model. + + .. note:: + The algorithm is specifically designed for quantile regression model. Intuitively, the set :math:`C(x)` just + grows or shrinks the distance between the quantiles by :math:`\hat{q}` to achieve coverage. However, this + function can also be applied to regression model without quantiles being provided. In this case, both + :math:`\hat{t}_{\alpha/2}(x)` and :math:`\hat{t}_{1-\alpha/2}(x)` are the same as :math:`\hat{y}`. Then, the + interval would be the same for every data point (i.e., :math:`\left[-\hat{q}, \hat{q} \right]`). + + Parameters + ---------- + alpha: float + The error rate, :math:`\alpha \in [0, 1]` + + References + ---------- + .. [angelopoulos2021] Angelopoulos, A.N.; Bates, S.; "A Gentle Introduction to Conformal Prediction and Distribution-Free + Uncertainty Quantification." arXiv Preprint 2021, https://arxiv.org/abs/2107.07511 + """ + + def __init__(self, alpha: float): + super().__init__() + self.alpha = alpha + self.bounds = torch.tensor([-1 / 2, 1 / 2]).view(-1, 1) + if not 0 <= self.alpha <= 1: + raise ValueError(f"arg `alpha` must be between 0 and 1. got: {alpha}.") + + def fit(self, preds: Tensor, uncs: Tensor, targets: Tensor, mask: Tensor) -> Self: + self.qhats = [] + for j in range(preds.shape[1]): + mask_j = mask[:, j] + targets_j = targets[:, j][mask_j] + preds_j = preds[:, j][mask_j] + interval_j = uncs[:, j][mask_j] + + interval_bounds = self.bounds * interval_j.unsqueeze(0) + pred_bounds = preds_j.unsqueeze(0) + interval_bounds + + calibration_scores = torch.max(pred_bounds[0] - targets_j, targets_j - pred_bounds[1]) + + num_data = targets_j.shape[0] + if self.alpha >= 1 / (num_data + 1): + q_level = math.ceil((num_data + 1) * (1 - self.alpha)) / num_data + else: + q_level = 1 + logger.warning( + "The error rate (i.e., `alpha`) is smaller than `1 / (number of data + 1)`, so the `1 - alpha` quantile is set to 1, " + "but this only ensures that the coverage is trivially satisfied." + ) + qhat = torch.quantile(calibration_scores, q_level, interpolation="higher") + self.qhats.append(qhat) + + self.qhats = torch.tensor(self.qhats) + return self + + def apply(self, uncs: Tensor) -> tuple[Tensor, Tensor]: + """ + Apply this calibrator to the input uncertainties. + + Parameters + ---------- + uncs: Tensor + a tensor containinig uncalibrated uncertainties + + Returns + ------- + Tensor + the calibrated intervals + """ + cal_intervals = uncs + 2 * self.qhats + + return cal_intervals + + +class BinaryClassificationCalibrator(CalibratorBase): + """ + A class for calibrating the predicted uncertainties in binary classification tasks. + """ + + @abstractmethod + def fit(self, uncs: Tensor, targets: Tensor, mask: Tensor) -> Self: + """ + Fit calibration method for the calibration data. + + Parameters + ---------- + uncs: Tensor + the predicted uncertainties (i.e., the predicted probability of class 1) of the shape of ``n x t``, where ``n`` is the number of input + molecules/reactions, and ``t`` is the number of tasks. 
+ targets: Tensor + a tensor of the shape ``n x t`` + mask: Tensor + a tensor of the shape ``n x t`` indicating whether the given values should be used in the fitting + + Returns + ------- + self : BinaryClassificationCalibrator + the fitted calibrator + """ + + +@UncertaintyCalibratorRegistry.register("platt") +class PlattCalibrator(BinaryClassificationCalibrator): + """Calibrate classification datasets using the Platt scaling algorithm [guo2017]_, [platt1999]_. + + In [platt1999]_, Platt suggests using the number of positive and negative training examples to + adjust the value of target probabilities used to fit the parameters. + + References + ---------- + .. [guo2017] Guo, C.; Pleiss, G.; Sun, Y.; Weinberger, K. Q. "On calibration of modern neural + networks". ICML, 2017. https://arxiv.org/abs/1706.04599 + .. [platt1999] Platt, J. "Probabilistic Outputs for Support Vector Machines and Comparisons to + Regularized Likelihood Methods." Adv. Large Margin Classif. 1999, 10 (3), 61–74. + """ + + def fit( + self, uncs: Tensor, targets: Tensor, mask: Tensor, training_targets: Tensor | None = None + ) -> Self: + if torch.any((targets[mask] != 0) & (targets[mask] != 1)): + raise ValueError( + "Platt scaling is only implemented for binary classification tasks! Input tensor " + "must contain only 0's and 1's." + ) + + if training_targets is not None: + logger.info( + "Training targets were provided. Platt scaling for calibration uses a Bayesian " + "correction to avoid training set overfitting. Now replacing calibration targets " + "[0, 1] with adjusted values." + ) + + n_negative_examples = (training_targets == 0).sum(dim=0) + n_positive_examples = (training_targets == 1).sum(dim=0) + + negative_target_bayes_MAP = (1 / (n_negative_examples + 2)).expand_as(targets) + positive_target_bayes_MAP = ( + (n_positive_examples + 1) / (n_positive_examples + 2) + ).expand_as(targets) + + targets = targets.float() + targets[targets == 0] = negative_target_bayes_MAP[targets == 0] + targets[targets == 1] = positive_target_bayes_MAP[targets == 1] + else: + logger.info("No training targets were provided. No Bayesian correction is applied.") + + xs = [] + for j in range(uncs.shape[1]): + mask_j = mask[:, j] + uncs_j = uncs[:, j][mask_j].numpy() + targets_j = targets[:, j][mask_j].numpy() + + def objective(parameters): + a, b = parameters + scaled_uncs = expit(a * logit(uncs_j) + b) + nll = -1 * np.sum( + targets_j * np.log(scaled_uncs) + (1 - targets_j) * np.log(1 - scaled_uncs) + ) + return nll + + xs.append(fmin(objective, x0=[1, 0], disp=False)) + + xs = np.vstack(xs) + self.a, self.b = torch.tensor(xs).T.unbind(dim=0) + + return self + + def apply(self, uncs: Tensor) -> Tensor: + return torch.sigmoid(self.a * torch.logit(uncs) + self.b) + + +@UncertaintyCalibratorRegistry.register("isotonic") +class IsotonicCalibrator(BinaryClassificationCalibrator): + """Calibrate binary classification datasets using isotonic regression as discussed in [guo2017]_. + In effect, the method transforms incoming uncalibrated confidences using a histogram-like + function where the range of each transforming bin and its magnitude is learned. + + References + ---------- + .. [guo2017] Guo, C.; Pleiss, G.; Sun, Y.; Weinberger, K. Q. "On calibration of modern neural + networks". ICML, 2017. 
https://arxiv.org/abs/1706.04599 + """ + + def fit(self, uncs: Tensor, targets: Tensor, mask: Tensor) -> Self: + if torch.any((targets[mask] != 0) & (targets[mask] != 1)): + raise ValueError( + "Isotonic calibration is only implemented for binary classification tasks! Input " + "tensor must contain only 0's and 1's." + ) + + isotonic_models = [] + for j in range(uncs.shape[1]): + mask_j = mask[:, j] + uncs_j = uncs[:, j][mask_j].numpy() + targets_j = targets[:, j][mask_j].numpy() + + isotonic_model = IsotonicRegression(y_min=0, y_max=1, out_of_bounds="clip") + isotonic_model.fit(uncs_j, targets_j) + isotonic_models.append(isotonic_model) + + self.isotonic_models = isotonic_models + + return self + + def apply(self, uncs: Tensor) -> Tensor: + cal_uncs = [] + for j, isotonic_model in enumerate(self.isotonic_models): + cal_uncs.append(isotonic_model.predict(uncs[:, j].numpy())) + return torch.tensor(np.array(cal_uncs)).t() + + +@UncertaintyCalibratorRegistry.register("conformal-multilabel") +class MultilabelConformalCalibrator(BinaryClassificationCalibrator): + r"""Creates conformal in-set and conformal out-set such that, for :math:`1-\alpha` proportion of datapoints, + the set of labels is bounded by the in- and out-sets [1]_: + + .. math:: + \Pr \left( + \hat{\mathcal C}_{\text{in}}(X) \subseteq \mathcal Y \subseteq \hat{\mathcal C}_{\text{out}}(X) + \right) \geq 1 - \alpha, + + where the in-set :math:`\hat{\mathcal C}_\text{in}` is contained by the set of true labels :math:`\mathcal Y` and + :math:`\mathcal Y` is contained within the out-set :math:`\hat{\mathcal C}_\text{out}`. + + Parameters + ---------- + alpha: float + The error rate, :math:`\alpha \in [0, 1]` + + References + ---------- + .. [1] Cauchois, M.; Gupta, S.; Duchi, J.; "Knowing What You Know: Valid and Validated Confidence Sets + in Multiclass and Multilabel Prediction." arXiv Preprint 2020, https://arxiv.org/abs/2004.10181 + """ + + def __init__(self, alpha: float): + super().__init__() + self.alpha = alpha + if not 0 <= self.alpha <= 1: + raise ValueError(f"arg `alpha` must be between 0 and 1. got: {alpha}.") + + @staticmethod + def nonconformity_scores(preds: Tensor): + r""" + Compute nonconformity score as the negative of the predicted probability. + + .. math:: + s_i = -\hat{f}(X_i)_{Y_i} + """ + return -preds + + def fit(self, uncs: Tensor, targets: Tensor, mask: Tensor) -> Self: + if targets.shape[1] < 2: + raise ValueError( + f"the number of tasks should be larger than 1! got: {targets.shape[1]}." 
+ ) + + has_zeros = torch.any(targets == 0, dim=1) + index_zeros = targets[has_zeros] == 0 + scores_in = self.nonconformity_scores(uncs[has_zeros]) + masked_scores_in = scores_in * index_zeros.float() + torch.where( + index_zeros, torch.zeros_like(scores_in), torch.tensor(float("inf")) + ) + calibration_scores_in = torch.min( + masked_scores_in.masked_fill(~mask, float("inf")), dim=1 + ).values + + has_ones = torch.any(targets == 1, dim=1) + index_ones = targets[has_ones] == 1 + scores_out = self.nonconformity_scores(uncs[has_ones]) + masked_scores_out = scores_out * index_ones.float() + torch.where( + index_ones, torch.zeros_like(scores_out), torch.tensor(float("-inf")) + ) + calibration_scores_out = torch.max( + masked_scores_out.masked_fill(~mask, float("-inf")), dim=1 + ).values + + self.tout = torch.quantile( + calibration_scores_out, 1 - self.alpha / 2, interpolation="higher" + ) + self.tin = torch.quantile(calibration_scores_in, self.alpha / 2, interpolation="higher") + return self + + def apply(self, uncs: Tensor) -> Tensor: + """ + Apply this calibrator to the input uncertainties. + + Parameters + ---------- + uncs: Tensor + a tensor containinig uncalibrated uncertainties + + Returns + ------- + Tensor + the calibrated uncertainties of the shape of ``n x t x 2``, where ``n`` is the number of input + molecules/reactions, ``t`` is the number of tasks, and the first element in the last dimension + corresponds to the in-set :math:`\hat{\mathcal C}_\text{in}`, while the second corresponds to + the out-set :math:`\hat{\mathcal C}_\text{out}`. + """ + scores = self.nonconformity_scores(uncs) + + cal_preds_in = (scores <= self.tin).int() + cal_preds_out = (scores <= self.tout).int() + cal_preds_in_out = torch.stack((cal_preds_in, cal_preds_out), dim=2) + + return cal_preds_in_out + + +class MulticlassClassificationCalibrator(CalibratorBase): + """ + A class for calibrating the predicted uncertainties in multiclass classification tasks. + """ + + @abstractmethod + def fit(self, uncs: Tensor, targets: Tensor, mask: Tensor) -> Self: + """ + Fit calibration method for the calibration data. + + Parameters + ---------- + uncs: Tensor + the predicted uncertainties (i.e., the predicted probabilities for each class) of the + shape of ``n x t x c``, where ``n`` is the number of input molecules/reactions, ``t`` is + the number of tasks, and ``c`` is the number of classes. + targets: Tensor + a tensor of the shape ``n x t`` + mask: Tensor + a tensor of the shape ``n x t`` indicating whether the given values should be used in + the fitting + + Returns + ------- + self : MulticlassClassificationCalibrator + the fitted calibrator + """ + + +@UncertaintyCalibratorRegistry.register("conformal-multiclass") +class MulticlassConformalCalibrator(MulticlassClassificationCalibrator): + r"""Create a prediction sets of possible labels :math:`C(X_{\text{test}}) \subset \{1 \mathrel{.\,.} K\}` that follows: + + .. math:: + 1 - \alpha \leq \Pr (Y_{\text{test}} \in C(X_{\text{test}})) \leq 1 - \alpha + \frac{1}{n + 1} + + In other words, the probability that the prediction set contains the correct label is almost exactly :math:`1-\alpha`. + More detailes can be found in [1]_. + + Parameters + ---------- + alpha: float + Error rate, :math:`\alpha \in [0, 1]` + + References + ---------- + .. [1] Angelopoulos, A.N.; Bates, S.; "A Gentle Introduction to Conformal Prediction and Distribution-Free + Uncertainty Quantification." 
arXiv Preprint 2021, https://arxiv.org/abs/2107.07511 + """ + + def __init__(self, alpha: float): + super().__init__() + self.alpha = alpha + if not 0 <= self.alpha <= 1: + raise ValueError(f"arg `alpha` must be between 0 and 1. got: {alpha}.") + + @staticmethod + def nonconformity_scores(preds: Tensor): + r"""Compute nonconformity score as the negative of the softmax output for the true class. + + .. math:: + s_i = -\hat{f}(X_i)_{Y_i} + """ + return -preds + + def fit(self, uncs: Tensor, targets: Tensor, mask: Tensor) -> Self: + self.qhats = [] + scores = self.nonconformity_scores(uncs) + for j in range(uncs.shape[1]): + mask_j = mask[:, j] + targets_j = targets[:, j][mask_j] + scores_j = scores[:, j][mask_j] + + scores_j = torch.gather(scores_j, 1, targets_j.unsqueeze(1)).squeeze(1) + num_data = targets_j.shape[0] + if self.alpha >= 1 / (num_data + 1): + q_level = math.ceil((num_data + 1) * (1 - self.alpha)) / num_data + else: + q_level = 1 + logger.warning( + "`alpha` is smaller than `1 / (number of data + 1)`, so the `1 - alpha` quantile is set to 1, " + "but this only ensures that the coverage is trivially satisfied." + ) + qhat = torch.quantile(scores_j, q_level, interpolation="higher") + self.qhats.append(qhat) + + self.qhats = torch.tensor(self.qhats) + return self + + def apply(self, uncs: Tensor) -> Tensor: + calibrated_preds = torch.zeros_like(uncs, dtype=torch.int) + scores = self.nonconformity_scores(uncs) + + for j, qhat in enumerate(self.qhats): + calibrated_preds[:, j] = (scores[:, j] <= qhat).int() + + return calibrated_preds + + +@UncertaintyCalibratorRegistry.register("conformal-adaptive") +class AdaptiveMulticlassConformalCalibrator(MulticlassConformalCalibrator): + @staticmethod + def nonconformity_scores(preds): + r"""Compute nonconformity score by greedily including classes in the classification set until it reaches the true label. + + .. math:: + s(x, y) = \sum_{j=1}^{k} \hat{f}(x)_{\pi_j(x)}, \text{ where } y = \pi_k(x) + + where :math:`\pi_k(x)` is the permutation of :math:`\{1 \mathrel{.\,.} K\}` that sorts :math:`\hat{f}(X_{test})` from most likely to least likely. + """ + + sort_index = torch.argsort(-preds, dim=2) + sorted_preds = torch.gather(preds, 2, sort_index) + sorted_scores = sorted_preds.cumsum(dim=2) + unsorted_scores = torch.zeros_like(sorted_scores).scatter_(2, sort_index, sorted_scores) + + return unsorted_scores + + +@UncertaintyCalibratorRegistry.register("isotonic-multiclass") +class IsotonicMulticlassCalibrator(MulticlassClassificationCalibrator): + """Calibrate multiclass classification datasets using isotonic regression as discussed in + [guo2017]_. It uses a one-vs-all aggregation scheme to extend isotonic regression from binary to + multiclass classifiers. + + References + ---------- + .. [guo2017] Guo, C.; Pleiss, G.; Sun, Y.; Weinberger, K. Q. "On calibration of modern neural + networks". ICML, 2017. 
https://arxiv.org/abs/1706.04599 + """ + + def fit(self, uncs: Tensor, targets: Tensor, mask: Tensor) -> Self: + isotonic_models = [] + for j in range(uncs.shape[1]): + mask_j = mask[:, j] + uncs_j = uncs[:, j, :][mask_j].numpy() + targets_j = targets[:, j][mask_j].numpy() + + class_isotonic_models = [] + for k in range(uncs.shape[2]): + class_uncs_j = uncs_j[..., k] + positive_class_targets = targets_j == k + + class_targets = np.ones_like(class_uncs_j) + class_targets[positive_class_targets] = 1 + class_targets[~positive_class_targets] = 0 + + isotonic_model = IsotonicRegression(y_min=0, y_max=1, out_of_bounds="clip") + isotonic_model.fit(class_uncs_j, class_targets) + class_isotonic_models.append(isotonic_model) + + isotonic_models.append(class_isotonic_models) + + self.isotonic_models = isotonic_models + + return self + + def apply(self, uncs: Tensor) -> Tensor: + cal_uncs = torch.zeros_like(uncs) + for j, class_isotonic_models in enumerate(self.isotonic_models): + for k, isotonic_model in enumerate(class_isotonic_models): + class_uncs_j = uncs[:, j, k].numpy() + class_cal_uncs = isotonic_model.predict(class_uncs_j) + cal_uncs[:, j, k] = torch.tensor(class_cal_uncs) + return cal_uncs / cal_uncs.sum(dim=-1, keepdim=True) diff --git a/chemprop-updated/chemprop/uncertainty/estimator.py b/chemprop-updated/chemprop/uncertainty/estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..95269ac62896ab9ae1ad34ba6053777bc9a1a207 --- /dev/null +++ b/chemprop-updated/chemprop/uncertainty/estimator.py @@ -0,0 +1,376 @@ +from abc import ABC, abstractmethod +from typing import Iterable + +from lightning import pytorch as pl +import torch +from torch import Tensor +from torch.utils.data import DataLoader + +from chemprop.models.model import MPNN +from chemprop.utils.registry import ClassRegistry + + +class UncertaintyEstimator(ABC): + """A helper class for making model predictions and associated uncertainty predictions.""" + + @abstractmethod + def __call__( + self, dataloader: DataLoader, models: Iterable[MPNN], trainer: pl.Trainer + ) -> tuple[Tensor, Tensor]: + """ + Calculate the uncalibrated predictions and uncertainties for the dataloader. + + dataloader: DataLoader + the dataloader used for model predictions and uncertainty predictions + models: Iterable[MPNN] + the models used for model predictions and uncertainty predictions + trainer: pl.Trainer + an instance of the :class:`~lightning.pytorch.trainer.trainer.Trainer` used to manage model inference + + Returns + ------- + preds : Tensor + the model predictions, with shape varying by task type: + + * regression/binary classification: ``m x n x t`` + + * multiclass classification: ``m x n x t x c``, where ``m`` is the number of models, + ``n`` is the number of inputs, ``t`` is the number of tasks, and ``c`` is the number of classes. + uncs : Tensor + the predicted uncertainties, with shapes of ``m' x n x t``. + + .. note:: + The ``m`` and ``m'`` are different by definition. The ``m`` is the number of models, + while the ``m'`` is the number of uncertainty estimations. For example, if two MVE + or evidential models are provided, both ``m`` and ``m'`` are two. However, for an + ensemble of two models, ``m'`` would be one (even though ``m = 2``). 
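As a concrete illustration of the shapes above, a hedged sketch of driving a registered estimator; ``models`` (a list of two or more trained ``MPNN`` instances) and ``loader`` (a prediction dataloader) are assumed to have been built elsewhere and are not defined here.

    from lightning import pytorch as pl
    from chemprop.uncertainty.estimator import UncertaintyEstimatorRegistry

    trainer = pl.Trainer(accelerator="cpu", logger=False, enable_progress_bar=False)
    estimator = UncertaintyEstimatorRegistry["ensemble"]()  # variance across submodels
    preds, uncs = estimator(loader, models, trainer)        # preds: m x n x t, uncs: 1 x n x t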
+ """ + + +UncertaintyEstimatorRegistry = ClassRegistry[UncertaintyEstimator]() + + +@UncertaintyEstimatorRegistry.register("none") +class NoUncertaintyEstimator(UncertaintyEstimator): + def __call__( + self, dataloader: DataLoader, models: Iterable[MPNN], trainer: pl.Trainer + ) -> tuple[Tensor, Tensor]: + predss = [] + for model in models: + preds = torch.concat(trainer.predict(model, dataloader), 0) + predss.append(preds) + return torch.stack(predss), None + + +@UncertaintyEstimatorRegistry.register("mve") +class MVEEstimator(UncertaintyEstimator): + """ + Class that estimates prediction means and variances (MVE). [nix1994]_ + + References + ---------- + .. [nix1994] Nix, D. A.; Weigend, A. S. "Estimating the mean and variance of the target + probability distribution." Proceedings of 1994 IEEE International Conference on Neural + Networks, 1994 https://doi.org/10.1109/icnn.1994.374138 + """ + + def __call__( + self, dataloader: DataLoader, models: Iterable[MPNN], trainer: pl.Trainer + ) -> tuple[Tensor, Tensor]: + mves = [] + for model in models: + preds = torch.concat(trainer.predict(model, dataloader), 0) + mves.append(preds) + mves = torch.stack(mves, dim=0) + mean, var = mves.unbind(dim=-1) + return mean, var + + +@UncertaintyEstimatorRegistry.register("ensemble") +class EnsembleEstimator(UncertaintyEstimator): + """ + Class that predicts the uncertainty of predictions based on the variance in predictions among + an ensemble's submodels. + """ + + def __call__( + self, dataloader: DataLoader, models: Iterable[MPNN], trainer: pl.Trainer + ) -> tuple[Tensor, Tensor]: + if len(models) <= 1: + raise ValueError( + "Ensemble method for uncertainty is only available when multiple models are provided." + ) + ensemble_preds = [] + for model in models: + preds = torch.concat(trainer.predict(model, dataloader), 0) + ensemble_preds.append(preds) + stacked_preds = torch.stack(ensemble_preds).float() + vars = torch.var(stacked_preds, dim=0, correction=0).unsqueeze(0) + return stacked_preds, vars + + +@UncertaintyEstimatorRegistry.register("classification") +class ClassEstimator(UncertaintyEstimator): + def __call__( + self, dataloader: DataLoader, models: Iterable[MPNN], trainer: pl.Trainer + ) -> tuple[Tensor, Tensor]: + predss = [] + for model in models: + preds = torch.concat(trainer.predict(model, dataloader), 0) + predss.append(preds) + return torch.stack(predss), torch.stack(predss) + + +@UncertaintyEstimatorRegistry.register("evidential-total") +class EvidentialTotalEstimator(UncertaintyEstimator): + """ + Class that predicts the total evidential uncertainty based on hyperparameters of + the evidential distribution [amini2020]_. + + References + ----------- + .. [amini2020] Amini, A.; Schwarting, W.; Soleimany, A.; Rus, D. "Deep Evidential Regression". + NeurIPS, 2020. https://arxiv.org/abs/1910.02600 + """ + + def __call__( + self, dataloader: DataLoader, models: Iterable[MPNN], trainer: pl.Trainer + ) -> tuple[Tensor, Tensor]: + uncs = [] + for model in models: + preds = torch.concat(trainer.predict(model, dataloader), 0) + uncs.append(preds) + uncs = torch.stack(uncs) + mean, v, alpha, beta = uncs.unbind(-1) + total_uncs = (1 + 1 / v) * (beta / (alpha - 1)) + return mean, total_uncs + + +@UncertaintyEstimatorRegistry.register("evidential-epistemic") +class EvidentialEpistemicEstimator(UncertaintyEstimator): + """ + Class that predicts the epistemic evidential uncertainty based on hyperparameters of + the evidential distribution. 
+ """ + + def __call__( + self, dataloader: DataLoader, models: Iterable[MPNN], trainer: pl.Trainer + ) -> tuple[Tensor, Tensor]: + uncs = [] + for model in models: + preds = torch.concat(trainer.predict(model, dataloader), 0) + uncs.append(preds) + uncs = torch.stack(uncs) + mean, v, alpha, beta = uncs.unbind(-1) + epistemic_uncs = (1 / v) * (beta / (alpha - 1)) + return mean, epistemic_uncs + + +@UncertaintyEstimatorRegistry.register("evidential-aleatoric") +class EvidentialAleatoricEstimator(UncertaintyEstimator): + """ + Class that predicts the aleatoric evidential uncertainty based on hyperparameters of + the evidential distribution. + """ + + def __call__( + self, dataloader: DataLoader, models: Iterable[MPNN], trainer: pl.Trainer + ) -> tuple[Tensor, Tensor]: + uncs = [] + for model in models: + preds = torch.concat(trainer.predict(model, dataloader), 0) + uncs.append(preds) + uncs = torch.stack(uncs) + mean, _, alpha, beta = uncs.unbind(-1) + aleatoric_uncs = beta / (alpha - 1) + return mean, aleatoric_uncs + + +@UncertaintyEstimatorRegistry.register("dropout") +class DropoutEstimator(UncertaintyEstimator): + """ + A :class:`DropoutEstimator` creates a virtual ensemble of models via Monte Carlo dropout with + the provided model [gal2016]_. + + Parameters + ---------- + ensemble_size: int + The number of samples to draw for the ensemble. + dropout: float | None + The probability of dropping out units in the dropout layers. If unspecified, + the training probability is used, which is prefered but not possible if the model was not + trained with dropout (i.e. p=0). + + References + ----------- + .. [gal2016] Gal, Y.; Ghahramani, Z. "Dropout as a bayesian approximation: Representing model uncertainty in deep learning." + International conference on machine learning. PMLR, 2016. 
https://arxiv.org/abs/1506.02142 + """ + + def __init__(self, ensemble_size: int, dropout: None | float = None): + self.ensemble_size = ensemble_size + self.dropout = dropout + + def __call__( + self, dataloader: DataLoader, models: Iterable[MPNN], trainer: pl.Trainer + ) -> tuple[Tensor, Tensor]: + meanss, varss = [], [] + for model in models: + self._setup_model(model) + individual_preds = [] + + for _ in range(self.ensemble_size): + predss = trainer.predict(model, dataloader) + preds = torch.concat(predss, 0) + individual_preds.append(preds) + + stacked_preds = torch.stack(individual_preds, dim=0).float() + means = torch.mean(stacked_preds, dim=0) + vars = torch.var(stacked_preds, dim=0, correction=0) + self._restore_model(model) + meanss.append(means) + varss.append(vars) + return torch.stack(meanss), torch.stack(varss) + + def _setup_model(self, model): + model._predict_step = model.predict_step + model.predict_step = self._predict_step(model) + model.apply(self._change_dropout) + + def _restore_model(self, model): + model.predict_step = model._predict_step + del model._predict_step + model.apply(self._restore_dropout) + + def _predict_step(self, model): + def _wrapped_predict_step(*args, **kwargs): + model.apply(self._activate_dropout) + return model._predict_step(*args, **kwargs) + + return _wrapped_predict_step + + def _activate_dropout(self, module): + if isinstance(module, torch.nn.Dropout): + module.train() + + def _change_dropout(self, module): + if isinstance(module, torch.nn.Dropout): + module._p = module.p + if self.dropout: + module.p = self.dropout + + def _restore_dropout(self, module): + if isinstance(module, torch.nn.Dropout): + if hasattr(module, "_p"): + module.p = module._p + del module._p + + +# TODO: Add in v2.1.x +# @UncertaintyEstimatorRegistry.register("spectra-roundrobin") +# class RoundRobinSpectraEstimator(UncertaintyEstimator): +# def __call__( +# self, dataloader: DataLoader, models: Iterable[MPNN], trainer: pl.Trainer +# ) -> tuple[Tensor, Tensor]: +# return + + +@UncertaintyEstimatorRegistry.register("classification-dirichlet") +class ClassificationDirichletEstimator(UncertaintyEstimator): + """ + A :class:`ClassificationDirichletEstimator` predicts an amount of 'evidence' for both the + negative class and the positive class as described in [sensoy2018]_. The class probabilities and + the uncertainty are calculated based on the evidence. + + .. math:: + S = \sum_{i=1}^K \alpha_i + p_i = \alpha_i / S + u = K / S + + where :math:`K` is the number of classes, :math:`\alpha_i` is the evidence for class :math:`i`, + :math:`p_i` is the probability of class :math:`i`, and :math:`u` is the uncertainty. + + References + ---------- + .. [sensoy2018] Sensoy, M.; Kaplan, L.; Kandemir, M. "Evidential deep learning to quantify + classification uncertainty." NeurIPS, 2018, 31. https://doi.org/10.48550/arXiv.1806.01768 + """ + + def __call__( + self, dataloader: DataLoader, models: Iterable[MPNN], trainer: pl.Trainer + ) -> tuple[Tensor, Tensor]: + uncs = [] + for model in models: + preds = torch.concat(trainer.predict(model, dataloader), 0) + uncs.append(preds) + uncs = torch.stack(uncs, dim=0) + y, u = uncs.unbind(dim=-1) + return y, u + + +@UncertaintyEstimatorRegistry.register("multiclass-dirichlet") +class MulticlassDirichletEstimator(UncertaintyEstimator): + """ + A :class:`MulticlassDirichletEstimator` predicts an amount of 'evidence' for each class as + described in [sensoy2018]_. The class probabilities and the uncertainty are calculated based on + the evidence. 
+ + .. math:: + S = \sum_{i=1}^K \alpha_i + p_i = \alpha_i / S + u = K / S + + where :math:`K` is the number of classes, :math:`\alpha_i` is the evidence for class :math:`i`, + :math:`p_i` is the probability of class :math:`i`, and :math:`u` is the uncertainty. + + References + ---------- + .. [sensoy2018] Sensoy, M.; Kaplan, L.; Kandemir, M. "Evidential deep learning to quantify + classification uncertainty." NeurIPS, 2018, 31. https://doi.org/10.48550/arXiv.1806.01768 + """ + + def __call__( + self, dataloader: DataLoader, models: Iterable[MPNN], trainer: pl.Trainer + ) -> tuple[Tensor, Tensor]: + preds = [] + uncs = [] + for model in models: + self._setup_model(model) + output = torch.concat(trainer.predict(model, dataloader), 0) + self._restore_model(model) + preds.append(output[..., :-1]) + uncs.append(output[..., -1]) + preds = torch.stack(preds, 0) + uncs = torch.stack(uncs, 0) + + return preds, uncs + + def _setup_model(self, model): + model.predictor._forward = model.predictor.forward + model.predictor.forward = self._forward.__get__(model.predictor, model.predictor.__class__) + + def _restore_model(self, model): + model.predictor.forward = model.predictor._forward + del model.predictor._forward + + def _forward(self, Z: Tensor) -> Tensor: + alpha = self.train_step(Z) + + u = alpha.shape[2] / alpha.sum(-1, keepdim=True) + Y = alpha / alpha.sum(-1, keepdim=True) + + return torch.concat([Y, u], -1) + + +@UncertaintyEstimatorRegistry.register("quantile-regression") +class QuantileRegressionEstimator(UncertaintyEstimator): + def __call__( + self, dataloader: DataLoader, models: Iterable[MPNN], trainer: pl.Trainer + ) -> tuple[Tensor, Tensor]: + individual_preds = [] + for model in models: + predss = trainer.predict(model, dataloader) + individual_preds.append(torch.concat(predss, 0)) + stacked_preds = torch.stack(individual_preds).float() + mean, interval = stacked_preds.unbind(2) + return mean, interval diff --git a/chemprop-updated/chemprop/uncertainty/evaluator.py b/chemprop-updated/chemprop/uncertainty/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..1e88fab2835b0b29aadd654949176b38ff899476 --- /dev/null +++ b/chemprop-updated/chemprop/uncertainty/evaluator.py @@ -0,0 +1,368 @@ +from abc import ABC, abstractmethod + +import numpy as np +import torch +from torch import Tensor +from torchmetrics.regression import SpearmanCorrCoef + +from chemprop.utils.registry import ClassRegistry + +UncertaintyEvaluatorRegistry = ClassRegistry() + + +class RegressionEvaluator(ABC): + """Evaluates the quality of uncertainty estimates in regression tasks.""" + + @abstractmethod + def evaluate(self, preds: Tensor, uncs: Tensor, targets: Tensor, mask: Tensor) -> Tensor: + """Evaluate the performance of uncertainty predictions against the model target values. + + Parameters + ---------- + preds: Tensor + the predictions for regression tasks. It is a tensor of the shape of ``n x t``, where ``n`` is + the number of input molecules/reactions, and ``t`` is the number of tasks. 
+ uncs: Tensor + the predicted uncertainties of the shape of ``n x t`` + targets: Tensor + a tensor of the shape ``n x t`` + mask: Tensor + a tensor of the shape ``n x t`` indicating whether the given values should be used in the evaluation + + Returns + ------- + Tensor + a tensor of the shape ``t`` containing the evaluated metrics + """ + + +@UncertaintyEvaluatorRegistry.register("nll-regression") +class NLLRegressionEvaluator(RegressionEvaluator): + r""" + Evaluate uncertainty values for regression datasets using the mean negative-log-likelihood + of the targets given the probability distributions estimated by the model: + + .. math:: + + \mathrm{NLL}(y, \hat y) = \frac{1}{2} \log(2 \pi \sigma^2) + \frac{(y - \hat{y})^2}{2 \sigma^2} + + where :math:`\hat{y}` is the predicted value, :math:`y` is the true value, and + :math:`\sigma^2` is the predicted uncertainty (variance). + + The function returns a tensor containing the mean NLL for each task. + """ + + def evaluate(self, preds: Tensor, uncs: Tensor, targets: Tensor, mask: Tensor) -> Tensor: + nlls = [] + for j in range(uncs.shape[1]): + mask_j = mask[:, j] + preds_j = preds[:, j][mask_j] + targets_j = targets[:, j][mask_j] + uncs_j = uncs[:, j][mask_j] + errors = preds_j - targets_j + nll = (2 * torch.pi * uncs_j).log() / 2 + errors**2 / (2 * uncs_j) + nlls.append(nll.mean(dim=0)) + return torch.stack(nlls) + + +@UncertaintyEvaluatorRegistry.register("miscalibration_area") +class CalibrationAreaEvaluator(RegressionEvaluator): + """ + A class for evaluating regression uncertainty values based on how they deviate from perfect + calibration on an observed-probability versus expected-probability plot. + """ + + def evaluate( + self, preds: Tensor, uncs: Tensor, targets: Tensor, mask: Tensor, num_bins: int = 100 + ) -> Tensor: + """Evaluate the performance of uncertainty predictions against the model target values. + + Parameters + ---------- + preds: Tensor + the predictions for regression tasks. It is a tensor of the shape of ``n x t``, where ``n`` is + the number of input molecules/reactions, and ``t`` is the number of tasks. + uncs: Tensor + the predicted uncertainties (variance) of the shape of ``n x t`` + targets: Tensor + a tensor of the shape ``n x t`` + mask: Tensor + a tensor of the shape ``n x t`` indicating whether the given values should be used in the evaluation + num_bins: int, default=100 + the number of bins to discretize the ``[0, 1]`` interval + + Returns + ------- + Tensor + a tensor of the shape ``t`` containing the evaluated metrics + """ + bins = torch.arange(1, num_bins) + bin_scaling = torch.special.erfinv(bins / num_bins).view(-1, 1, 1) * np.sqrt(2) + errors = torch.abs(preds - targets) + uncs = torch.sqrt(uncs).unsqueeze(0) + bin_unc = uncs * bin_scaling + bin_count = bin_unc >= errors.unsqueeze(0) + mask = mask.unsqueeze(0) + observed_auc = (bin_count & mask).sum(1) / mask.sum(1) + num_tasks = uncs.shape[-1] + observed_auc = torch.cat( + [torch.zeros(1, num_tasks), observed_auc, torch.ones(1, num_tasks)] + ).T + ideal_auc = torch.arange(num_bins + 1) / num_bins + miscal_area = (1 / num_bins) * (observed_auc - ideal_auc).abs().sum(dim=1) + return miscal_area + + +@UncertaintyEvaluatorRegistry.register("ence") +class ExpectedNormalizedErrorEvaluator(RegressionEvaluator): + r""" + A class that evaluates uncertainty performance by binning together clusters of predictions + and comparing the average predicted variance of the clusters against the RMSE of the cluster. [1]_ + + .. 
math:: + \mathrm{ENCE} = \frac{1}{N} \sum_{i=1}^{N} \frac{|\mathrm{RMV}_i - \mathrm{RMSE}_i|}{\mathrm{RMV}_i} + + where :math:`N` is the number of bins, :math:`\mathrm{RMV}_i` is the root of the mean uncertainty over the + :math:`i`-th bin and :math:`\mathrm{RMSE}_i` is the root mean square error over the :math:`i`-th bin. This + discrepancy is further normalized by the uncertainty over the bin, :math:`\mathrm{RMV}_i`, because the error + is expected to be naturally higher as the uncertainty increases. + + References + ---------- + .. [1] Levi, D.; Gispan, L.; Giladi, N.; Fetaya, E. "Evaluating and Calibrating Uncertainty Prediction in Regression Tasks." + Sensors, 2022, 22(15), 5540. https://www.mdpi.com/1424-8220/22/15/5540 + """ + + def evaluate( + self, preds: Tensor, uncs: Tensor, targets: Tensor, mask: Tensor, num_bins: int = 100 + ) -> Tensor: + """Evaluate the performance of uncertainty predictions against the model target values. + + Parameters + ---------- + preds: Tensor + the predictions for regression tasks. It is a tensor of the shape of ``n x t``, where ``n`` is + the number of input molecules/reactions, and ``t`` is the number of tasks. + uncs: Tensor + the predicted uncertainties (variance) of the shape of ``n x t`` + targets: Tensor + a tensor of the shape ``n x t`` + mask: Tensor + a tensor of the shape ``n x t`` indicating whether the given values should be used in the evaluation + num_bins: int, default=100 + the number of bins the data are divided into + + Returns + ------- + Tensor + a tensor of the shape ``t`` containing the evaluated metrics + """ + masked_preds = preds * mask + masked_targets = targets * mask + masked_uncs = uncs * mask + errors = torch.abs(masked_preds - masked_targets) + + sort_idx = torch.argsort(masked_uncs, dim=0) + sorted_uncs = torch.gather(masked_uncs, 0, sort_idx) + sorted_errors = torch.gather(errors, 0, sort_idx) + + split_unc = torch.chunk(sorted_uncs, num_bins, dim=0) + split_error = torch.chunk(sorted_errors, num_bins, dim=0) + + root_mean_vars = torch.sqrt(torch.stack([chunk.mean(0) for chunk in split_unc])) + rmses = torch.sqrt(torch.stack([chunk.pow(2).mean(0) for chunk in split_error])) + + ence = torch.mean(torch.abs(root_mean_vars - rmses) / root_mean_vars, dim=0) + return ence + + +@UncertaintyEvaluatorRegistry.register("spearman") +class SpearmanEvaluator(RegressionEvaluator): + """ + Evaluate the Spearman rank correlation coefficient between the uncertainties and errors in the model predictions. + + The correlation coefficient returns a value in the [-1, 1] range, with better scores closer to 1 + observed when the uncertainty values are predictive of the rank ordering of the errors in the model prediction. + """ + + def evaluate(self, preds: Tensor, uncs: Tensor, targets: Tensor, mask: Tensor) -> Tensor: + spearman_coeffs = [] + for j in range(uncs.shape[1]): + mask_j = mask[:, j] + preds_j = preds[:, j][mask_j] + targets_j = targets[:, j][mask_j] + uncs_j = uncs[:, j][mask_j] + errs_j = (preds_j - targets_j).abs() + spearman = SpearmanCorrCoef() + spearman_coeff = spearman(uncs_j, errs_j) + spearman_coeffs.append(spearman_coeff) + return torch.stack(spearman_coeffs) + + +@UncertaintyEvaluatorRegistry.register("conformal-coverage-regression") +class RegressionConformalEvaluator(RegressionEvaluator): + r""" + Evaluate the coverage of conformal prediction for regression datasets. + + .. math:: + \Pr (Y_{\text{test}} \in C(X_{\text{test}})) + + where the :math:`C(X_{\text{test}})` is the predicted interval. 
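A small worked example, assuming the evaluator is imported from the module added in this patch; each prediction carries a symmetric interval of total width 2, so three of the four targets fall inside their interval:

    import torch
    from chemprop.uncertainty.evaluator import RegressionConformalEvaluator

    preds = torch.zeros(4, 1)
    uncs = torch.full((4, 1), 2.0)  # interval width 2 -> [-1, 1] around each prediction
    targets = torch.tensor([[0.5], [-0.8], [1.5], [0.2]])
    mask = torch.ones(4, 1, dtype=torch.bool)

    coverage = RegressionConformalEvaluator().evaluate(preds, uncs, targets, mask)  # ~0.75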
+ """ + + def evaluate(self, preds: Tensor, uncs: Tensor, targets: Tensor, mask: Tensor) -> Tensor: + bounds = torch.tensor([-1 / 2, 1 / 2], device=mask.device) + interval = uncs.unsqueeze(0) * bounds.view([-1] + [1] * preds.ndim) + lower, upper = preds.unsqueeze(0) + interval + covered_mask = torch.logical_and(lower <= targets, targets <= upper) + + return (covered_mask & mask).sum(0) / mask.sum(0) + + +class BinaryClassificationEvaluator(ABC): + """Evaluates the quality of uncertainty estimates in binary classification tasks.""" + + @abstractmethod + def evaluate(self, uncs: Tensor, targets: Tensor, mask: Tensor) -> Tensor: + """Evaluate the performance of uncertainty predictions against the model target values. + + Parameters + ---------- + uncs: Tensor + the predicted uncertainties (i.e., the predicted probability of class 1) of the shape of ``n x t``, where ``n`` is the number of input + molecules/reactions, and ``t`` is the number of tasks. + targets: Tensor + a tensor of the shape ``n x t`` + mask: Tensor + a tensor of the shape ``n x t`` indicating whether the given values should be used in the evaluation + + Returns + ------- + Tensor + a tensor of the shape ``t`` containing the evaluated metrics + """ + + +@UncertaintyEvaluatorRegistry.register("nll-classification") +class NLLClassEvaluator(BinaryClassificationEvaluator): + """ + Evaluate uncertainty values for binary classification datasets using the mean negative-log-likelihood + of the targets given the assigned probabilities from the model: + + .. math:: + + \mathrm{NLL} = -\log(\hat{y} \cdot y + (1 - \hat{y}) \cdot (1 - y)) + + where :math:`y` is the true binary label (0 or 1), and + :math:`\hat{y}` is the predicted probability associated with the class label 1. + + The function returns a tensor containing the mean NLL for each task. + """ + + def evaluate(self, uncs: Tensor, targets: Tensor, mask: Tensor) -> Tensor: + nlls = [] + for j in range(uncs.shape[1]): + mask_j = mask[:, j] + targets_j = targets[:, j][mask_j] + uncs_j = uncs[:, j][mask_j] + likelihood = uncs_j * targets_j + (1 - uncs_j) * (1 - targets_j) + nll = -1 * likelihood.log() + nlls.append(nll.mean(dim=0)) + return torch.stack(nlls) + + +@UncertaintyEvaluatorRegistry.register("conformal-coverage-classification") +class MultilabelConformalEvaluator(BinaryClassificationEvaluator): + r""" + Evaluate the coverage of conformal prediction for binary classification datasets with multiple labels. + + .. math:: + \Pr \left( + \hat{\mathcal C}_{\text{in}}(X) \subseteq \mathcal Y \subseteq \hat{\mathcal C}_{\text{out}}(X) + \right) + + where the in-set :math:`\hat{\mathcal C}_\text{in}` is contained by the set of true labels :math:`\mathcal Y` and + :math:`\mathcal Y` is contained within the out-set :math:`\hat{\mathcal C}_\text{out}`. + """ + + def evaluate(self, uncs: Tensor, targets: Tensor, mask: Tensor) -> Tensor: + in_set, out_set = torch.chunk(uncs, 2, 1) + covered_mask = torch.logical_and(in_set <= targets, targets <= out_set) + return (covered_mask & mask).sum(0) / mask.sum(0) + + +class MulticlassClassificationEvaluator(ABC): + """Evaluates the quality of uncertainty estimates in multiclass classification tasks.""" + + @abstractmethod + def evaluate(self, uncs: Tensor, targets: Tensor, mask: Tensor) -> Tensor: + """Evaluate the performance of uncertainty predictions against the model target values. 
+ + Parameters + ---------- + uncs: Tensor + the predicted uncertainties (i.e., the predicted probabilities for each class) of the shape of ``n x t x c``, where ``n`` is the number of input + molecules/reactions, ``t`` is the number of tasks, and ``c`` is the number of classes. + targets: Tensor + a tensor of the shape ``n x t`` + mask: Tensor + a tensor of the shape ``n x t`` indicating whether the given values should be used in the evaluation + + Returns + ------- + Tensor + a tensor of the shape ``t`` containing the evaluated metrics + """ + + +@UncertaintyEvaluatorRegistry.register("nll-multiclass") +class NLLMulticlassEvaluator(MulticlassClassificationEvaluator): + """ + Evaluate uncertainty values for multiclass classification datasets using the mean negative-log-likelihood + of the targets given the assigned probabilities from the model: + + .. math:: + + \mathrm{NLL} = -\log(p_{y_i}) + + where :math:`p_{y_i}` is the predicted probability for the true class :math:`y_i`, calculated as: + + .. math:: + + p_{y_i} = \sum_{k=1}^{K} \mathbb{1}(y_i = k) \cdot p_k + + Here: :math:`K` is the total number of classes, + :math:`\mathbb{1}(y_i = k)` is the indicator function that is 1 when the true class :math:`y_i` equals class :math:`k`, and 0 otherwise, + and :math:`p_k` is the predicted probability for class :math:`k`. + + The function returns a tensor containing the mean NLL for each task. + """ + + def evaluate(self, uncs: Tensor, targets: Tensor, mask: Tensor) -> Tensor: + nlls = [] + for j in range(uncs.shape[1]): + mask_j = mask[:, j] + targets_j = targets[:, j][mask_j] + uncs_j = uncs[:, j][mask_j] + targets_one_hot = torch.eye(uncs_j.shape[-1])[targets_j.long()] + likelihood = (targets_one_hot * uncs_j).sum(dim=-1) + nll = -1 * likelihood.log() + nlls.append(nll.mean(dim=0)) + return torch.stack(nlls) + + +@UncertaintyEvaluatorRegistry.register("conformal-coverage-multiclass") +class MulticlassConformalEvaluator(MulticlassClassificationEvaluator): + r""" + Evaluate the coverage of conformal prediction for multiclass classification datasets. + + .. math:: + \Pr (Y_{\text{test}} \in C(X_{\text{test}})) + + where the :math:`C(X_{\text{test}}) \subset \{1 \mathrel{.\,.} K\}` is a prediction set of possible labels . 
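A small hedged example using 0/1 prediction sets such as those produced by ``MulticlassConformalCalibrator``; the true class is inside the set for two of the three samples, so the coverage is 2/3:

    import torch
    from chemprop.uncertainty.evaluator import MulticlassConformalEvaluator

    uncs = torch.tensor([[[1, 1, 0]], [[0, 1, 0]], [[1, 0, 0]]])  # n x t x c prediction sets
    targets = torch.tensor([[0], [1], [2]])                       # true class indices, n x t
    mask = torch.ones(3, 1, dtype=torch.bool)

    coverage = MulticlassConformalEvaluator().evaluate(uncs, targets, mask)  # ~0.67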
+ """ + + def evaluate(self, uncs: Tensor, targets: Tensor, mask: Tensor) -> Tensor: + targets_one_hot = torch.nn.functional.one_hot(targets, num_classes=uncs.shape[2]) + covered_mask = torch.max(uncs * targets_one_hot, dim=-1)[0] > 0 + return (covered_mask & mask).sum(0) / mask.sum(0) diff --git a/chemprop-updated/chemprop/utils/__init__.py b/chemprop-updated/chemprop/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1a8937a6e06591c8b5eb19bbe5ae00851364351a --- /dev/null +++ b/chemprop-updated/chemprop/utils/__init__.py @@ -0,0 +1,4 @@ +from .registry import ClassRegistry, Factory +from .utils import EnumMapping, make_mol, pretty_shape + +__all__ = ["ClassRegistry", "Factory", "EnumMapping", "make_mol", "pretty_shape"] diff --git a/chemprop-updated/chemprop/utils/registry.py b/chemprop-updated/chemprop/utils/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..58137351965bc80749f10654abce8d8c4d8570e0 --- /dev/null +++ b/chemprop-updated/chemprop/utils/registry.py @@ -0,0 +1,46 @@ +import inspect +from typing import Any, Iterable, Type, TypeVar + +T = TypeVar("T") + + +class ClassRegistry(dict[str, Type[T]]): + def register(self, alias: Any | Iterable[Any] | None = None): + def decorator(cls): + if alias is None: + keys = [cls.__name__.lower()] + elif isinstance(alias, str): + keys = [alias] + else: + keys = alias + + cls.alias = keys[0] + for k in keys: + self[k] = cls + + return cls + + return decorator + + __call__ = register + + def __repr__(self) -> str: # pragma: no cover + return f"{self.__class__.__name__}: {super().__repr__()}" + + def __str__(self) -> str: # pragma: no cover + INDENT = 4 + items = [f"{' ' * INDENT}{repr(k)}: {repr(v)}" for k, v in self.items()] + + return "\n".join([f"{self.__class__.__name__} {'{'}", ",\n".join(items), "}"]) + + +class Factory: + @classmethod + def build(cls, clz_T: Type[T], *args, **kwargs) -> T: + if not inspect.isclass(clz_T): + raise TypeError(f"Expected a class type! got: {type(clz_T)}") + + sig = inspect.signature(clz_T) + kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters.keys()} + + return clz_T(*args, **kwargs) diff --git a/chemprop-updated/chemprop/utils/utils.py b/chemprop-updated/chemprop/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d7cd4d0c68439563c085847c5153f66849e637b --- /dev/null +++ b/chemprop-updated/chemprop/utils/utils.py @@ -0,0 +1,86 @@ +from __future__ import annotations + +from enum import StrEnum +from typing import Iterable, Iterator + +from rdkit import Chem + + +class EnumMapping(StrEnum): + @classmethod + def get(cls, name: str | EnumMapping) -> EnumMapping: + if isinstance(name, cls): + return name + + try: + return cls[name.upper()] + except KeyError: + raise KeyError( + f"Unsupported {cls.__name__} member! got: '{name}'. expected one of: {cls.keys()}" + ) + + @classmethod + def keys(cls) -> Iterator[str]: + return (e.name for e in cls) + + @classmethod + def values(cls) -> Iterator[str]: + return (e.value for e in cls) + + @classmethod + def items(cls) -> Iterator[tuple[str, str]]: + return zip(cls.keys(), cls.values()) + + +def make_mol(smi: str, keep_h: bool, add_h: bool, ignore_chirality: bool = False) -> Chem.Mol: + """build an RDKit molecule from a SMILES string. + + Parameters + ---------- + smi : str + a SMILES string. + keep_h : bool + whether to keep hydrogens in the input smiles. 
This does not add hydrogens, it only keeps them if they are specified + add_h : bool + If True, adds hydrogens to the molecule. + ignore_chirality : bool, optional + If True, ignores chirality information when constructing the molecule. Default is False. + + Returns + ------- + Chem.Mol + the RDKit molecule. + """ + if keep_h: + mol = Chem.MolFromSmiles(smi, sanitize=False) + Chem.SanitizeMol( + mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ Chem.SanitizeFlags.SANITIZE_ADJUSTHS + ) + else: + mol = Chem.MolFromSmiles(smi) + + if mol is None: + raise RuntimeError(f"SMILES {smi} is invalid! (RDKit returned None)") + + if add_h: + mol = Chem.AddHs(mol) + + if ignore_chirality: + for atom in mol.GetAtoms(): + atom.SetChiralTag(Chem.ChiralType.CHI_UNSPECIFIED) + + return mol + + +def pretty_shape(shape: Iterable[int]) -> str: + """Make a pretty string from an input shape + + Example + -------- + >>> X = np.random.rand(10, 4) + >>> X.shape + (10, 4) + >>> pretty_shape(X.shape) + '10 x 4' + """ + return " x ".join(map(str, shape)) diff --git a/chemprop-updated/chemprop/utils/v1_to_v2.py b/chemprop-updated/chemprop/utils/v1_to_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..cd059340a147db7cd383fac6d60ebf1cf87debb2 --- /dev/null +++ b/chemprop-updated/chemprop/utils/v1_to_v2.py @@ -0,0 +1,188 @@ +from os import PathLike + +from lightning.pytorch import __version__ +from lightning.pytorch.utilities.parsing import AttributeDict +import torch + +from chemprop.nn.agg import AggregationRegistry +from chemprop.nn.message_passing import AtomMessagePassing, BondMessagePassing +from chemprop.nn.metrics import LossFunctionRegistry, MetricRegistry +from chemprop.nn.predictors import PredictorRegistry +from chemprop.nn.transforms import UnscaleTransform +from chemprop.utils import Factory + + +def convert_state_dict_v1_to_v2(model_v1_dict: dict) -> dict: + """Converts v1 model dictionary to a v2 state dictionary""" + + state_dict_v2 = {} + args_v1 = model_v1_dict["args"] + + state_dict_v1 = model_v1_dict["state_dict"] + state_dict_v2["message_passing.W_i.weight"] = state_dict_v1["encoder.encoder.0.W_i.weight"] + state_dict_v2["message_passing.W_h.weight"] = state_dict_v1["encoder.encoder.0.W_h.weight"] + state_dict_v2["message_passing.W_o.weight"] = state_dict_v1["encoder.encoder.0.W_o.weight"] + state_dict_v2["message_passing.W_o.bias"] = state_dict_v1["encoder.encoder.0.W_o.bias"] + + # v1.6 renamed ffn to readout + if "readout.1.weight" in state_dict_v1: + for i in range(args_v1.ffn_num_layers): + suffix = 0 if i == 0 else 2 + state_dict_v2[f"predictor.ffn.{i}.{suffix}.weight"] = state_dict_v1[ + f"readout.{i * 3 + 1}.weight" + ] + state_dict_v2[f"predictor.ffn.{i}.{suffix}.bias"] = state_dict_v1[ + f"readout.{i * 3 + 1}.bias" + ] + else: + for i in range(args_v1.ffn_num_layers): + suffix = 0 if i == 0 else 2 + state_dict_v2[f"predictor.ffn.{i}.{suffix}.weight"] = state_dict_v1[ + f"ffn.{i * 3 + 1}.weight" + ] + state_dict_v2[f"predictor.ffn.{i}.{suffix}.bias"] = state_dict_v1[ + f"ffn.{i * 3 + 1}.bias" + ] + + if args_v1.dataset_type == "regression": + state_dict_v2["predictor.output_transform.mean"] = torch.tensor( + model_v1_dict["data_scaler"]["means"], dtype=torch.float32 + ).unsqueeze(0) + state_dict_v2["predictor.output_transform.scale"] = torch.tensor( + model_v1_dict["data_scaler"]["stds"], dtype=torch.float32 + ).unsqueeze(0) + + # target_weights was added in #183 + if getattr(args_v1, "target_weights", None) is not None: + task_weights = 
torch.tensor(args_v1.target_weights).unsqueeze(0) + else: + task_weights = torch.ones(args_v1.num_tasks).unsqueeze(0) + + state_dict_v2["predictor.criterion.task_weights"] = task_weights + + return state_dict_v2 + + +def convert_hyper_parameters_v1_to_v2(model_v1_dict: dict) -> dict: + """Converts v1 model dictionary to v2 hyper_parameters dictionary""" + hyper_parameters_v2 = {} + renamed_metrics = { + "auc": "roc", + "prc-auc": "prc", + "cross_entropy": "ce", + "binary_cross_entropy": "bce", + "mcc": "binary-mcc", + "recall": "recall is not in v2", + "precision": "precision is not in v2", + "balanced_accuracy": "balanced_accuracy is not in v2", + } + + args_v1 = model_v1_dict["args"] + hyper_parameters_v2["batch_norm"] = False + hyper_parameters_v2["metrics"] = [ + Factory.build(MetricRegistry[renamed_metrics.get(args_v1.metric, args_v1.metric)]) + ] + hyper_parameters_v2["warmup_epochs"] = args_v1.warmup_epochs + hyper_parameters_v2["init_lr"] = args_v1.init_lr + hyper_parameters_v2["max_lr"] = args_v1.max_lr + hyper_parameters_v2["final_lr"] = args_v1.final_lr + + # convert the message passing block + W_i_shape = model_v1_dict["state_dict"]["encoder.encoder.0.W_i.weight"].shape + W_h_shape = model_v1_dict["state_dict"]["encoder.encoder.0.W_h.weight"].shape + W_o_shape = model_v1_dict["state_dict"]["encoder.encoder.0.W_o.weight"].shape + + d_h = W_i_shape[0] + d_v = W_o_shape[1] - d_h + d_e = W_h_shape[1] - d_h if args_v1.atom_messages else W_i_shape[1] - d_v + + hyper_parameters_v2["message_passing"] = AttributeDict( + { + "activation": args_v1.activation, + "bias": args_v1.bias, + "cls": BondMessagePassing if not args_v1.atom_messages else AtomMessagePassing, + "d_e": d_e, # the feature dimension of the edges + "d_h": args_v1.hidden_size, # dimension of the hidden layer + "d_v": d_v, # the feature dimension of the vertices + "d_vd": args_v1.atom_descriptors_size, + "depth": args_v1.depth, + "dropout": args_v1.dropout, + "undirected": args_v1.undirected, + } + ) + + # convert the aggregation block + hyper_parameters_v2["agg"] = { + "dim": 0, # in v1, the aggregation is always done on the atom features + "cls": AggregationRegistry[args_v1.aggregation], + } + if args_v1.aggregation == "norm": + hyper_parameters_v2["agg"]["norm"] = args_v1.aggregation_norm + + # convert the predictor block + fgs = args_v1.features_generator or [] + d_xd = sum((200 if "rdkit" in fg else 0) + (2048 if "morgan" in fg else 0) for fg in fgs) + + if getattr(args_v1, "target_weights", None) is not None: + task_weights = torch.tensor(args_v1.target_weights).unsqueeze(0) + else: + task_weights = torch.ones(args_v1.num_tasks).unsqueeze(0) + + # loss_function was added in #238 + loss_fn_defaults = { + "classification": "bce", + "regression": "mse", + "multiclass": "ce", + "specitra": "sid", + } + T_loss_fn = LossFunctionRegistry[ + getattr(args_v1, "loss_function", loss_fn_defaults[args_v1.dataset_type]) + ] + + hyper_parameters_v2["predictor"] = AttributeDict( + { + "activation": args_v1.activation, + "cls": PredictorRegistry[args_v1.dataset_type], + "criterion": Factory.build(T_loss_fn, task_weights=task_weights), + "task_weights": None, + "dropout": args_v1.dropout, + "hidden_dim": args_v1.ffn_hidden_size, + "input_dim": args_v1.hidden_size + args_v1.atom_descriptors_size + d_xd, + "n_layers": args_v1.ffn_num_layers - 1, + "n_tasks": args_v1.num_tasks, + } + ) + + if args_v1.dataset_type == "regression": + hyper_parameters_v2["predictor"]["output_transform"] = UnscaleTransform( + 
model_v1_dict["data_scaler"]["means"], model_v1_dict["data_scaler"]["stds"] + ) + + return hyper_parameters_v2 + + +def convert_model_dict_v1_to_v2(model_v1_dict: dict) -> dict: + """Converts a v1 model dictionary from a loaded .pt file to a v2 model dictionary""" + + model_v2_dict = {} + + model_v2_dict["epoch"] = None + model_v2_dict["global_step"] = None + model_v2_dict["pytorch-lightning_version"] = __version__ + model_v2_dict["state_dict"] = convert_state_dict_v1_to_v2(model_v1_dict) + model_v2_dict["loops"] = None + model_v2_dict["callbacks"] = None + model_v2_dict["optimizer_states"] = None + model_v2_dict["lr_schedulers"] = None + model_v2_dict["hparams_name"] = "kwargs" + model_v2_dict["hyper_parameters"] = convert_hyper_parameters_v1_to_v2(model_v1_dict) + + return model_v2_dict + + +def convert_model_file_v1_to_v2(model_v1_file: PathLike, model_v2_file: PathLike) -> None: + """Converts a v1 model .pt file to a v2 model .pt file""" + + model_v1_dict = torch.load(model_v1_file, map_location=torch.device("cpu"), weights_only=False) + model_v2_dict = convert_model_dict_v1_to_v2(model_v1_dict) + torch.save(model_v2_dict, model_v2_file) diff --git a/chemprop-updated/chemprop/utils/v2_0_to_v2_1.py b/chemprop-updated/chemprop/utils/v2_0_to_v2_1.py new file mode 100644 index 0000000000000000000000000000000000000000..8627637bc63d6594547f8b4f401e52d43808bb16 --- /dev/null +++ b/chemprop-updated/chemprop/utils/v2_0_to_v2_1.py @@ -0,0 +1,40 @@ +import pickle +import sys + +import torch + + +class Unpickler(pickle.Unpickler): + name_mappings = { + "MSELoss": "MSE", + "MSEMetric": "MSE", + "MAEMetric": "MAE", + "RMSEMetric": "RMSE", + "BoundedMSELoss": "BoundedMSE", + "BoundedMSEMetric": "BoundedMSE", + "BoundedMAEMetric": "BoundedMAE", + "BoundedRMSEMetric": "BoundedRMSE", + "SIDLoss": "SID", + "SIDMetric": "SID", + "WassersteinLoss": "Wasserstein", + "WassersteinMetric": "Wasserstein", + "R2Metric": "R2Score", + "BinaryAUROCMetric": "BinaryAUROC", + "BinaryAUPRCMetric": "BinaryAUPRC", + "BinaryAccuracyMetric": "BinaryAccuracy", + "BinaryF1Metric": "BinaryF1Score", + "BCEMetric": "BCELoss", + } + + def find_class(self, module, name): + if module == "chemprop.nn.loss": + module = "chemprop.nn.metrics" + name = self.name_mappings.get(name, name) + return super().find_class(module, name) + + +if __name__ == "__main__": + model = torch.load( + sys.argv[1], map_location="cpu", pickle_module=sys.modules[__name__], weights_only=False + ) + torch.save(model, sys.argv[2]) diff --git a/chemprop-updated/chemprop/web/__pycache__/__init__.cpython-37.pyc b/chemprop-updated/chemprop/web/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e583789216f0a21aa82e5c4521bfe4e7e2e70979 Binary files /dev/null and b/chemprop-updated/chemprop/web/__pycache__/__init__.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/web/__pycache__/config.cpython-37.pyc b/chemprop-updated/chemprop/web/__pycache__/config.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b98aa54568a30ac337c685689bd24259deaa5bc Binary files /dev/null and b/chemprop-updated/chemprop/web/__pycache__/config.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/web/__pycache__/run.cpython-37.pyc b/chemprop-updated/chemprop/web/__pycache__/run.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9244cc7bcf7465d40584fdbdb7b94d816832035a Binary files /dev/null and b/chemprop-updated/chemprop/web/__pycache__/run.cpython-37.pyc 
differ diff --git a/chemprop-updated/chemprop/web/__pycache__/utils.cpython-37.pyc b/chemprop-updated/chemprop/web/__pycache__/utils.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98bc7d90a7a7819fed797504bcc66e32deac6e61 Binary files /dev/null and b/chemprop-updated/chemprop/web/__pycache__/utils.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/web/app/__pycache__/__init__.cpython-37.pyc b/chemprop-updated/chemprop/web/app/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b429c2989b044e1459612ef2e200beaf4b58720 Binary files /dev/null and b/chemprop-updated/chemprop/web/app/__pycache__/__init__.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/web/app/__pycache__/db.cpython-37.pyc b/chemprop-updated/chemprop/web/app/__pycache__/db.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0771eb4cfafc14b740d8b5a12d037816f94fa006 Binary files /dev/null and b/chemprop-updated/chemprop/web/app/__pycache__/db.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/web/app/__pycache__/views.cpython-37.pyc b/chemprop-updated/chemprop/web/app/__pycache__/views.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7afc9091c06166c507ae72789ba64b4a8332a283 Binary files /dev/null and b/chemprop-updated/chemprop/web/app/__pycache__/views.cpython-37.pyc differ diff --git a/chemprop-updated/chemprop/web/chemprop.sqlite3 b/chemprop-updated/chemprop/web/chemprop.sqlite3 new file mode 100644 index 0000000000000000000000000000000000000000..a4adae685e0264b067ded01dd9ca6099c06a2c9b Binary files /dev/null and b/chemprop-updated/chemprop/web/chemprop.sqlite3 differ
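A minimal usage sketch for the conversion utilities added above. The checkpoint file names below are hypothetical examples, and the module-style invocation of the v2.0-to-v2.1 script is an assumed way of running its __main__ block; only convert_model_file_v1_to_v2 and the v2_0_to_v2_1 script themselves come from the diff.

# Sketch: upgrade a Chemprop v1 checkpoint to the v2 on-disk format.
# "model_v1.pt" and "model_v2.pt" are hypothetical example paths.
from chemprop.utils.v1_to_v2 import convert_model_file_v1_to_v2

convert_model_file_v1_to_v2("model_v1.pt", "model_v2.pt")

# The v2.0 -> v2.1 loss/metric class renaming is handled by the separate
# v2_0_to_v2_1 script, which reads an input and output path from sys.argv.
# One assumed invocation (paths are again examples):
#     python -m chemprop.utils.v2_0_to_v2_1 model_v2_0.ckpt model_v2_1.ckpt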