hbhzm committed
Commit 1afebd5 · verified · Parent(s): 41e2665

Upload 111 files

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full set of changes.
Files changed (50)
  1. chemprop-updated/chemprop/__init__.py +5 -0
  2. chemprop-updated/chemprop/__pycache__/__init__.cpython-37.pyc +0 -0
  3. chemprop-updated/chemprop/__pycache__/args.cpython-37.pyc +0 -0
  4. chemprop-updated/chemprop/__pycache__/constants.cpython-37.pyc +0 -0
  5. chemprop-updated/chemprop/__pycache__/hyperopt_utils.cpython-37.pyc +0 -0
  6. chemprop-updated/chemprop/__pycache__/hyperparameter_optimization.cpython-37.pyc +0 -0
  7. chemprop-updated/chemprop/__pycache__/interpret.cpython-37.pyc +0 -0
  8. chemprop-updated/chemprop/__pycache__/multitask_utils.cpython-37.pyc +0 -0
  9. chemprop-updated/chemprop/__pycache__/nn_utils.cpython-37.pyc +0 -0
  10. chemprop-updated/chemprop/__pycache__/rdkit.cpython-37.pyc +0 -0
  11. chemprop-updated/chemprop/__pycache__/sklearn_predict.cpython-37.pyc +0 -0
  12. chemprop-updated/chemprop/__pycache__/sklearn_train.cpython-37.pyc +0 -0
  13. chemprop-updated/chemprop/__pycache__/spectra_utils.cpython-37.pyc +0 -0
  14. chemprop-updated/chemprop/__pycache__/utils.cpython-37.pyc +0 -0
  15. chemprop-updated/chemprop/cli/common.py +216 -0
  16. chemprop-updated/chemprop/cli/conf.py +9 -0
  17. chemprop-updated/chemprop/cli/convert.py +55 -0
  18. chemprop-updated/chemprop/cli/fingerprint.py +185 -0
  19. chemprop-updated/chemprop/cli/hpopt.py +540 -0
  20. chemprop-updated/chemprop/cli/main.py +85 -0
  21. chemprop-updated/chemprop/cli/predict.py +447 -0
  22. chemprop-updated/chemprop/cli/train.py +1343 -0
  23. chemprop-updated/chemprop/cli/utils/__init__.py +30 -0
  24. chemprop-updated/chemprop/cli/utils/actions.py +19 -0
  25. chemprop-updated/chemprop/cli/utils/args.py +34 -0
  26. chemprop-updated/chemprop/cli/utils/command.py +24 -0
  27. chemprop-updated/chemprop/cli/utils/parsing.py +457 -0
  28. chemprop-updated/chemprop/cli/utils/utils.py +31 -0
  29. chemprop-updated/chemprop/conf.py +6 -0
  30. chemprop-updated/chemprop/data/__init__.py +41 -0
  31. chemprop-updated/chemprop/data/__pycache__/__init__.cpython-37.pyc +0 -0
  32. chemprop-updated/chemprop/data/__pycache__/data.cpython-37.pyc +0 -0
  33. chemprop-updated/chemprop/data/__pycache__/scaffold.cpython-37.pyc +0 -0
  34. chemprop-updated/chemprop/data/__pycache__/scaler.cpython-37.pyc +0 -0
  35. chemprop-updated/chemprop/data/__pycache__/utils.cpython-37.pyc +0 -0
  36. chemprop-updated/chemprop/data/collate.py +123 -0
  37. chemprop-updated/chemprop/data/dataloader.py +71 -0
  38. chemprop-updated/chemprop/data/datapoints.py +150 -0
  39. chemprop-updated/chemprop/data/datasets.py +475 -0
  40. chemprop-updated/chemprop/data/molgraph.py +17 -0
  41. chemprop-updated/chemprop/data/samplers.py +66 -0
  42. chemprop-updated/chemprop/data/splitting.py +225 -0
  43. chemprop-updated/chemprop/exceptions.py +12 -0
  44. chemprop-updated/chemprop/features/__pycache__/__init__.cpython-37.pyc +0 -0
  45. chemprop-updated/chemprop/features/__pycache__/features_generators.cpython-37.pyc +0 -0
  46. chemprop-updated/chemprop/features/__pycache__/featurization.cpython-37.pyc +0 -0
  47. chemprop-updated/chemprop/features/__pycache__/utils.cpython-37.pyc +0 -0
  48. chemprop-updated/chemprop/featurizers/__init__.py +52 -0
  49. chemprop-updated/chemprop/featurizers/atom.py +281 -0
  50. chemprop-updated/chemprop/featurizers/base.py +30 -0
chemprop-updated/chemprop/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from . import data, exceptions, featurizers, models, nn, schedulers, utils
+
+ __all__ = ["data", "featurizers", "models", "nn", "utils", "exceptions", "schedulers"]
+
+ __version__ = "2.1.2"
chemprop-updated/chemprop/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (743 Bytes)
chemprop-updated/chemprop/__pycache__/args.cpython-37.pyc ADDED
Binary file (33.7 kB)
chemprop-updated/chemprop/__pycache__/constants.cpython-37.pyc ADDED
Binary file (430 Bytes)
chemprop-updated/chemprop/__pycache__/hyperopt_utils.cpython-37.pyc ADDED
Binary file (11.1 kB)
chemprop-updated/chemprop/__pycache__/hyperparameter_optimization.cpython-37.pyc ADDED
Binary file (6.15 kB)
chemprop-updated/chemprop/__pycache__/interpret.cpython-37.pyc ADDED
Binary file (14.2 kB)
chemprop-updated/chemprop/__pycache__/multitask_utils.cpython-37.pyc ADDED
Binary file (3.12 kB)
chemprop-updated/chemprop/__pycache__/nn_utils.cpython-37.pyc ADDED
Binary file (8.13 kB)
chemprop-updated/chemprop/__pycache__/rdkit.cpython-37.pyc ADDED
Binary file (1.43 kB)
chemprop-updated/chemprop/__pycache__/sklearn_predict.cpython-37.pyc ADDED
Binary file (2.82 kB)
chemprop-updated/chemprop/__pycache__/sklearn_train.cpython-37.pyc ADDED
Binary file (11.4 kB)
chemprop-updated/chemprop/__pycache__/spectra_utils.cpython-37.pyc ADDED
Binary file (5.1 kB)
chemprop-updated/chemprop/__pycache__/utils.cpython-37.pyc ADDED
Binary file (26.5 kB)
chemprop-updated/chemprop/cli/common.py ADDED
@@ -0,0 +1,216 @@
+ from argparse import ArgumentError, ArgumentParser, Namespace
+ import logging
+ from pathlib import Path
+
+ from chemprop.cli.utils import LookupAction
+ from chemprop.cli.utils.args import uppercase
+ from chemprop.featurizers import AtomFeatureMode, MoleculeFeaturizerRegistry, RxnMode
+
+ logger = logging.getLogger(__name__)
+
+
+ def add_common_args(parser: ArgumentParser) -> ArgumentParser:
+     data_args = parser.add_argument_group("Shared input data args")
+     data_args.add_argument(
+         "-s",
+         "--smiles-columns",
+         nargs="+",
+         help="Column names in the input CSV containing SMILES strings (uses the 0th column by default)",
+     )
+     data_args.add_argument(
+         "-r",
+         "--reaction-columns",
+         nargs="+",
+         help="Column names in the input CSV containing reaction SMILES in the format ``REACTANT>AGENT>PRODUCT``, where 'AGENT' is optional",
+     )
+     data_args.add_argument(
+         "--no-header-row",
+         action="store_true",
+         help="Turn off using the first row in the input CSV as column names",
+     )
+
+     dataloader_args = parser.add_argument_group("Dataloader args")
+     dataloader_args.add_argument(
+         "-n",
+         "--num-workers",
+         type=int,
+         default=0,
+         help="""Number of workers for parallel data loading where 0 means sequential
+ (Warning: setting ``num_workers`` to a value greater than 0 can cause hangs on Windows and MacOS)""",
+     )
+     dataloader_args.add_argument("-b", "--batch-size", type=int, default=64, help="Batch size")
+
+     parser.add_argument(
+         "--accelerator", default="auto", help="Passed directly to the lightning ``Trainer()``"
+     )
+     parser.add_argument(
+         "--devices",
+         default="auto",
+         help="Passed directly to the lightning ``Trainer()`` (must be a single string of comma separated devices, e.g. '1, 2' if specifying multiple devices)",
+     )
+
+     featurization_args = parser.add_argument_group("Featurization args")
+     featurization_args.add_argument(
+         "--rxn-mode",
+         "--reaction-mode",
+         type=uppercase,
+         default="REAC_DIFF",
+         choices=list(RxnMode.keys()),
+         help="""Choices for construction of atom and bond features for reactions (case insensitive):
+
+ - ``REAC_PROD``: concatenates the reactants feature with the products feature
+ - ``REAC_DIFF``: concatenates the reactants feature with the difference in features between reactants and products (Default)
+ - ``PROD_DIFF``: concatenates the products feature with the difference in features between reactants and products
+ - ``REAC_PROD_BALANCE``: concatenates the reactants feature with the products feature, balances imbalanced reactions
+ - ``REAC_DIFF_BALANCE``: concatenates the reactants feature with the difference in features between reactants and products, balances imbalanced reactions
+ - ``PROD_DIFF_BALANCE``: concatenates the products feature with the difference in features between reactants and products, balances imbalanced reactions""",
+     )
+     # TODO: Update documentation for multi_hot_atom_featurizer_mode
+     featurization_args.add_argument(
+         "--multi-hot-atom-featurizer-mode",
+         type=uppercase,
+         default="V2",
+         choices=list(AtomFeatureMode.keys()),
+         help="""Choices for multi-hot atom featurization scheme. This will affect both non-reaction and reaction featurization (case insensitive):
+
+ - ``V1``: Corresponds to the original configuration employed in Chemprop V1
+ - ``V2``: Tailored for a broad range of molecules, this configuration encompasses all elements in the first four rows of the periodic table, along with iodine. It is the default in Chemprop V2.
+ - ``ORGANIC``: This configuration is designed specifically for use with organic molecules for drug research and development and includes a subset of elements most common in organic chemistry, including H, B, C, N, O, F, Si, P, S, Cl, Br, and I.
+ - ``RIGR``: Modified V2 (default) featurizer using only the resonance-invariant atom and bond features.""",
+     )
+     featurization_args.add_argument(
+         "--keep-h",
+         action="store_true",
+         help="Whether hydrogens explicitly specified in input should be kept in the mol graph",
+     )
+     featurization_args.add_argument(
+         "--add-h", action="store_true", help="Whether hydrogens should be added to the mol graph"
+     )
+     data_args.add_argument(
+         "--ignore-chirality",
+         action="store_true",
+         help="Ignore chirality information in the input SMILES",
+     )
+     featurization_args.add_argument(
+         "--molecule-featurizers",
+         "--features-generators",
+         nargs="+",
+         action=LookupAction(MoleculeFeaturizerRegistry),
+         help="Method(s) of generating molecule features to use as extra descriptors",
+     )
+     # TODO: add in v2.1 to deprecate features-generators and then remove in v2.2
+     # featurization_args.add_argument(
+     #     "--features-generators", nargs="+", help="Renamed to `--molecule-featurizers`."
+     # )
+     featurization_args.add_argument(
+         "--descriptors-path",
+         type=Path,
+         help="Path to extra descriptors to concatenate to learned representation",
+     )
+     # TODO: Add in v2.1
+     # featurization_args.add_argument(
+     #     "--phase-features-path",
+     #     help="Path to features used to indicate the phase of the data in one-hot vector form. Used in spectra datatype.",
+     # )
+     featurization_args.add_argument(
+         "--no-descriptor-scaling", action="store_true", help="Turn off extra descriptor scaling"
+     )
+     featurization_args.add_argument(
+         "--no-atom-feature-scaling", action="store_true", help="Turn off extra atom feature scaling"
+     )
+     featurization_args.add_argument(
+         "--no-atom-descriptor-scaling",
+         action="store_true",
+         help="Turn off extra atom descriptor scaling",
+     )
+     featurization_args.add_argument(
+         "--no-bond-feature-scaling", action="store_true", help="Turn off extra bond feature scaling"
+     )
+     featurization_args.add_argument(
+         "--atom-features-path",
+         nargs="+",
+         action="append",
+         help="If a single path is given, it is assumed to correspond to the 0-th molecule. Alternatively, it can be a two-tuple of molecule index and path to additional atom features to supply before message passing (e.g., ``--atom-features-path 0 /path/to/features_0.npz`` indicates that the features at the given path should be supplied to the 0-th component). To supply additional features for multiple components, repeat this argument on the command line for each component's respective values (e.g., ``--atom-features-path [...] --atom-features-path [...]``).",
+     )
+     featurization_args.add_argument(
+         "--atom-descriptors-path",
+         nargs="+",
+         action="append",
+         help="If a single path is given, it is assumed to correspond to the 0-th molecule. Alternatively, it can be a two-tuple of molecule index and path to additional atom descriptors to supply after message passing (e.g., ``--atom-descriptors-path 0 /path/to/descriptors_0.npz`` indicates that the descriptors at the given path should be supplied to the 0-th component). To supply additional descriptors for multiple components, repeat this argument on the command line for each component's respective values (e.g., ``--atom-descriptors-path [...] --atom-descriptors-path [...]``).",
+     )
+     featurization_args.add_argument(
+         "--bond-features-path",
+         nargs="+",
+         action="append",
+         help="If a single path is given, it is assumed to correspond to the 0-th molecule. Alternatively, it can be a two-tuple of molecule index and path to additional bond features to supply before message passing (e.g., ``--bond-features-path 0 /path/to/features_0.npz`` indicates that the features at the given path should be supplied to the 0-th component). To supply additional features for multiple components, repeat this argument on the command line for each component's respective values (e.g., ``--bond-features-path [...] --bond-features-path [...]``).",
+     )
+     # TODO: Add in v2.2
+     # parser.add_argument(
+     #     "--constraints-path",
+     #     help="Path to constraints applied to atomic/bond properties prediction.",
+     # )
+
+     return parser
+
+
+ def process_common_args(args: Namespace) -> Namespace:
+     # TODO: add in v2.1 to deprecate features-generators and then remove in v2.2
+     # if args.features_generators is not None:
+     #     raise ArgumentError(
+     #         argument=None,
+     #         message="`--features-generators` has been renamed to `--molecule-featurizers`.",
+     #     )
+
+     for key in ["atom_features_path", "atom_descriptors_path", "bond_features_path"]:
+         inds_paths = getattr(args, key)
+
+         if not inds_paths:
+             continue
+
+         ind_path_dict = {}
+
+         for ind_path in inds_paths:
+             if len(ind_path) > 2:
+                 raise ArgumentError(
+                     argument=None,
+                     message="Too many arguments were given for atom features/descriptors or bond features. It should be either a two-tuple of molecule index and a path, or a single path (assumed to be the 0-th molecule).",
+                 )
+
+             if len(ind_path) == 1:
+                 ind = 0
+                 path = ind_path[0]
+             else:
+                 ind, path = ind_path
+
+             if ind_path_dict.get(int(ind), None):
+                 raise ArgumentError(
+                     argument=None,
+                     message=f"Duplicate atom features/descriptors or bond features given for molecule index {ind}",
+                 )
+
+             ind_path_dict[int(ind)] = Path(path)
+
+         setattr(args, key, ind_path_dict)
+
+     return args
+
+
+ def validate_common_args(args):
+     pass
+
+
+ def find_models(model_paths: list[Path]):
+     collected_model_paths = []
+
+     for model_path in model_paths:
+         if model_path.suffix in [".ckpt", ".pt"]:
+             collected_model_paths.append(model_path)
+         elif model_path.is_dir():
+             collected_model_paths.extend(list(model_path.rglob("*.pt")))
+         else:
+             raise ArgumentError(
+                 argument=None,
+                 message=f"Expected a .ckpt or .pt file, or a directory. Got {model_path}",
+             )
+
+     return collected_model_paths
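
Note: a minimal sketch (not part of the commit; file names are placeholders) of how process_common_args above folds repeated ``--atom-features-path`` flags into an index-to-path mapping:

    from argparse import Namespace
    from pathlib import Path

    # `--atom-features-path feats_0.npz --atom-features-path 1 feats_1.npz` arrives
    # from argparse (nargs="+", action="append") as a list of lists:
    args = Namespace(
        atom_features_path=[["feats_0.npz"], ["1", "feats_1.npz"]],  # placeholder files
        atom_descriptors_path=None,
        bond_features_path=None,
    )
    args = process_common_args(args)
    # a bare path maps to molecule index 0; an (index, path) pair maps to that index
    assert args.atom_features_path == {0: Path("feats_0.npz"), 1: Path("feats_1.npz")}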
chemprop-updated/chemprop/cli/conf.py ADDED
@@ -0,0 +1,9 @@
+ from datetime import datetime
+ import logging
+ import os
+ from pathlib import Path
+
+ LOG_DIR = Path(os.getenv("CHEMPROP_LOG_DIR", "chemprop_logs"))
+ LOG_LEVELS = {0: logging.INFO, 1: logging.DEBUG, -1: logging.WARNING, -2: logging.ERROR}
+ NOW = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+ CHEMPROP_TRAIN_DIR = Path(os.getenv("CHEMPROP_TRAIN_DIR", "chemprop_training"))
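
Note: both directories above honor environment variables; a sketch of the override (path is a placeholder):

    import os
    os.environ["CHEMPROP_LOG_DIR"] = "/tmp/my_logs"  # must be set before the module is imported
    # importing chemprop.cli.conf afterwards would yield LOG_DIR == Path("/tmp/my_logs")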
chemprop-updated/chemprop/cli/convert.py ADDED
@@ -0,0 +1,55 @@
+ from argparse import ArgumentError, ArgumentParser, Namespace
+ import logging
+ from pathlib import Path
+ import sys
+
+ from chemprop.cli.utils import Subcommand
+ from chemprop.utils.v1_to_v2 import convert_model_file_v1_to_v2
+
+ logger = logging.getLogger(__name__)
+
+
+ class ConvertSubcommand(Subcommand):
+     COMMAND = "convert"
+     HELP = "Convert a v1 model checkpoint (.pt) to a v2 model checkpoint (.pt)."
+
+     @classmethod
+     def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
+         parser.add_argument(
+             "-i",
+             "--input-path",
+             required=True,
+             type=Path,
+             help="Path to a v1 model .pt checkpoint file",
+         )
+         parser.add_argument(
+             "-o",
+             "--output-path",
+             type=Path,
+             help="Path to which the converted model will be saved (``CURRENT_DIRECTORY/STEM_OF_INPUT_v2.pt`` by default)",
+         )
+         return parser
+
+     @classmethod
+     def func(cls, args: Namespace):
+         if args.output_path is None:
+             args.output_path = Path(args.input_path.stem + "_v2.pt")
+         if args.output_path.suffix != ".pt":
+             raise ArgumentError(
+                 argument=None, message=f"Output must be a `.pt` file. Got {args.output_path}"
+             )
+
+         logger.info(
+             f"Converting v1 model checkpoint '{args.input_path}' to v2 model checkpoint '{args.output_path}'..."
+         )
+         convert_model_file_v1_to_v2(args.input_path, args.output_path)
+
+
+ if __name__ == "__main__":
+     parser = ArgumentParser()
+     parser = ConvertSubcommand.add_args(parser)
+
+     logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
+
+     args = parser.parse_args()
+     ConvertSubcommand.func(args)
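
Note: a sketch of the default output-path logic in func above, with a placeholder input path:

    from pathlib import Path

    input_path = Path("checkpoints/model_v1.pt")    # placeholder
    output_path = Path(input_path.stem + "_v2.pt")  # Path("model_v1_v2.pt"), relative to the CWD
    # i.e., `chemprop convert -i checkpoints/model_v1.pt` with no -o would write ./model_v1_v2.pt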
chemprop-updated/chemprop/cli/fingerprint.py ADDED
@@ -0,0 +1,185 @@
+ from argparse import ArgumentError, ArgumentParser, Namespace
+ import logging
+ from pathlib import Path
+ import sys
+
+ import numpy as np
+ import pandas as pd
+ import torch
+
+ from chemprop import data
+ from chemprop.cli.common import add_common_args, process_common_args, validate_common_args
+ from chemprop.cli.predict import find_models
+ from chemprop.cli.utils import Subcommand, build_data_from_files, make_dataset
+ from chemprop.models import load_model
+ from chemprop.nn.metrics import LossFunctionRegistry
+
+ logger = logging.getLogger(__name__)
+
+
+ class FingerprintSubcommand(Subcommand):
+     COMMAND = "fingerprint"
+     HELP = "Use a pretrained chemprop model to calculate learned representations."
+
+     @classmethod
+     def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
+         parser = add_common_args(parser)
+         parser.add_argument(
+             "-i",
+             "--test-path",
+             required=True,
+             type=Path,
+             help="Path to an input CSV file containing SMILES",
+         )
+         parser.add_argument(
+             "-o",
+             "--output",
+             "--preds-path",
+             type=Path,
+             help="Specify the path where predictions will be saved. If the file extension is .npz, they will be saved as an npz file. Otherwise, the predictions will be saved as a CSV. The index of the model will be appended to the filename's stem. By default, predictions will be saved to the same location as ``--test-path`` with '_fps' appended (e.g., 'PATH/TO/TEST_PATH_fps_0.csv').",
+         )
+         parser.add_argument(
+             "--model-paths",
+             "--model-path",
+             required=True,
+             type=Path,
+             nargs="+",
+             help="Specify location of checkpoint(s) or model file(s) to use for prediction. It can be a path to either a single pretrained model checkpoint (.ckpt) or single pretrained model file (.pt), a directory that contains these files, or a list of path(s) and directory(s). If a directory, chemprop will recursively search and predict on all found (.pt) models.",
+         )
+         parser.add_argument(
+             "--ffn-block-index",
+             required=True,
+             type=int,
+             default=-1,
+             help="The index indicates which linear layer returns the encoding in the FFN. An index of 0 denotes the post-aggregation representation through a 0-layer MLP, while an index of 1 represents the output from the first linear layer in the FFN, and so forth.",
+         )
+
+         return parser
+
+     @classmethod
+     def func(cls, args: Namespace):
+         args = process_common_args(args)
+         validate_common_args(args)
+         args = process_fingerprint_args(args)
+         main(args)
+
+
+ def process_fingerprint_args(args: Namespace) -> Namespace:
+     if args.test_path.suffix not in [".csv"]:
+         raise ArgumentError(
+             argument=None, message=f"Input data must be a CSV file. Got {args.test_path}"
+         )
+     if args.output is None:
+         args.output = args.test_path.parent / (args.test_path.stem + "_fps.csv")
+     if args.output.suffix not in [".csv", ".npz"]:
+         raise ArgumentError(
+             argument=None, message=f"Output must be a CSV or NPZ file. Got '{args.output}'."
+         )
+     return args
+
+
+ def make_fingerprint_for_model(
+     args: Namespace, model_path: Path, multicomponent: bool, output_path: Path
+ ):
+     model = load_model(model_path, multicomponent)
+     model.eval()
+
+     bounded = any(
+         isinstance(model.criterion, LossFunctionRegistry[loss_function])
+         for loss_function in LossFunctionRegistry.keys()
+         if "bounded" in loss_function
+     )
+
+     format_kwargs = dict(
+         no_header_row=args.no_header_row,
+         smiles_cols=args.smiles_columns,
+         rxn_cols=args.reaction_columns,
+         target_cols=[],
+         ignore_cols=None,
+         splits_col=None,
+         weight_col=None,
+         bounded=bounded,
+     )
+
+     featurization_kwargs = dict(
+         molecule_featurizers=args.molecule_featurizers,
+         keep_h=args.keep_h,
+         add_h=args.add_h,
+         ignore_chirality=args.ignore_chirality,
+     )
+
+     test_data = build_data_from_files(
+         args.test_path,
+         **format_kwargs,
+         p_descriptors=args.descriptors_path,
+         p_atom_feats=args.atom_features_path,
+         p_bond_feats=args.bond_features_path,
+         p_atom_descs=args.atom_descriptors_path,
+         **featurization_kwargs,
+     )
+     logger.info(f"test size: {len(test_data[0])}")
+     test_dsets = [
+         make_dataset(d, args.rxn_mode, args.multi_hot_atom_featurizer_mode) for d in test_data
+     ]
+
+     if multicomponent:
+         test_dset = data.MulticomponentDataset(test_dsets)
+     else:
+         test_dset = test_dsets[0]
+
+     test_loader = data.build_dataloader(test_dset, args.batch_size, args.num_workers, shuffle=False)
+
+     logger.info(model)
+
+     with torch.no_grad():
+         if multicomponent:
+             encodings = [
+                 model.encoding(batch.bmgs, batch.V_ds, batch.X_d, args.ffn_block_index)
+                 for batch in test_loader
+             ]
+         else:
+             encodings = [
+                 model.encoding(batch.bmg, batch.V_d, batch.X_d, args.ffn_block_index)
+                 for batch in test_loader
+             ]
+         H = torch.cat(encodings, 0).numpy()
+
+     if output_path.suffix in [".npz"]:
+         np.savez(output_path, H=H)
+     elif output_path.suffix == ".csv":
+         fingerprint_columns = [f"fp_{i}" for i in range(H.shape[1])]
+         df_fingerprints = pd.DataFrame(H, columns=fingerprint_columns)
+         df_fingerprints.to_csv(output_path, index=False)
+     else:
+         raise ArgumentError(
+             argument=None, message=f"Output must be a CSV or npz file. Got {args.output}."
+         )
+     logger.info(f"Fingerprints saved to '{output_path}'")
+
+
+ def main(args):
+     match (args.smiles_columns, args.reaction_columns):
+         case [None, None]:
+             n_components = 1
+         case [_, None]:
+             n_components = len(args.smiles_columns)
+         case [None, _]:
+             n_components = len(args.reaction_columns)
+         case _:
+             n_components = len(args.smiles_columns) + len(args.reaction_columns)
+
+     multicomponent = n_components > 1
+
+     for i, model_path in enumerate(find_models(args.model_paths)):
+         logger.info(f"Fingerprints with model {i} at '{model_path}'")
+         output_path = args.output.parent / f"{args.output.stem}_{i}{args.output.suffix}"
+         make_fingerprint_for_model(args, model_path, multicomponent, output_path)
+
+
+ if __name__ == "__main__":
+     parser = ArgumentParser()
+     parser = FingerprintSubcommand.add_args(parser)
+
+     logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
+     args = parser.parse_args()
+     args = FingerprintSubcommand.func(args)
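
Note: a sketch (illustrative file name, not from the commit) of the per-model output naming in main above, assuming the default output derived from a test file 'smiles.csv':

    from pathlib import Path

    output = Path("smiles_fps.csv")  # default: test-path stem + "_fps.csv"
    i = 0                            # model index from find_models(...)
    per_model = output.parent / f"{output.stem}_{i}{output.suffix}"
    assert per_model == Path("smiles_fps_0.csv")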
chemprop-updated/chemprop/cli/hpopt.py ADDED
@@ -0,0 +1,540 @@
+ from copy import deepcopy
+ import logging
+ from pathlib import Path
+ import shutil
+ import sys
+
+ from configargparse import ArgumentParser, Namespace
+ from lightning import pytorch as pl
+ from lightning.pytorch.callbacks import EarlyStopping
+ import numpy as np
+ import torch
+
+ from chemprop.cli.common import add_common_args, process_common_args, validate_common_args
+ from chemprop.cli.train import (
+     TrainSubcommand,
+     add_train_args,
+     build_datasets,
+     build_model,
+     build_splits,
+     normalize_inputs,
+     process_train_args,
+     save_config,
+     validate_train_args,
+ )
+ from chemprop.cli.utils.command import Subcommand
+ from chemprop.data import build_dataloader
+ from chemprop.nn import AggregationRegistry, MetricRegistry
+ from chemprop.nn.transforms import UnscaleTransform
+ from chemprop.nn.utils import Activation
+
+ NO_RAY = False
+ DEFAULT_SEARCH_SPACE = {
+     "activation": None,
+     "aggregation": None,
+     "aggregation_norm": None,
+     "batch_size": None,
+     "depth": None,
+     "dropout": None,
+     "ffn_hidden_dim": None,
+     "ffn_num_layers": None,
+     "final_lr_ratio": None,
+     "message_hidden_dim": None,
+     "init_lr_ratio": None,
+     "max_lr": None,
+     "warmup_epochs": None,
+ }
+
+ try:
+     import ray
+     from ray import tune
+     from ray.train import CheckpointConfig, RunConfig, ScalingConfig
+     from ray.train.lightning import (
+         RayDDPStrategy,
+         RayLightningEnvironment,
+         RayTrainReportCallback,
+         prepare_trainer,
+     )
+     from ray.train.torch import TorchTrainer
+     from ray.tune.schedulers import ASHAScheduler, FIFOScheduler
+
+     DEFAULT_SEARCH_SPACE = {
+         "activation": tune.choice(categories=list(Activation.keys())),
+         "aggregation": tune.choice(categories=list(AggregationRegistry.keys())),
+         "aggregation_norm": tune.quniform(lower=1, upper=200, q=1),
+         "batch_size": tune.choice([16, 32, 64, 128, 256]),
+         "depth": tune.qrandint(lower=2, upper=6, q=1),
+         "dropout": tune.choice([0.0] * 8 + list(np.arange(0.05, 0.45, 0.05))),
+         "ffn_hidden_dim": tune.qrandint(lower=300, upper=2400, q=100),
+         "ffn_num_layers": tune.qrandint(lower=1, upper=3, q=1),
+         "final_lr_ratio": tune.loguniform(lower=1e-2, upper=1),
+         "message_hidden_dim": tune.qrandint(lower=300, upper=2400, q=100),
+         "init_lr_ratio": tune.loguniform(lower=1e-2, upper=1),
+         "max_lr": tune.loguniform(lower=1e-4, upper=1e-2),
+         "warmup_epochs": None,
+     }
+ except ImportError:
+     NO_RAY = True
+
+ NO_HYPEROPT = False
+ try:
+     from ray.tune.search.hyperopt import HyperOptSearch
+ except ImportError:
+     NO_HYPEROPT = True
+
+ NO_OPTUNA = False
+ try:
+     from ray.tune.search.optuna import OptunaSearch
+ except ImportError:
+     NO_OPTUNA = True
+
+
+ logger = logging.getLogger(__name__)
+
+ SEARCH_SPACE = DEFAULT_SEARCH_SPACE
+
+ SEARCH_PARAM_KEYWORDS_MAP = {
+     "basic": ["depth", "ffn_num_layers", "dropout", "ffn_hidden_dim", "message_hidden_dim"],
+     "learning_rate": ["max_lr", "init_lr_ratio", "final_lr_ratio", "warmup_epochs"],
+     "all": list(DEFAULT_SEARCH_SPACE.keys()),
+     "init_lr": ["init_lr_ratio"],
+     "final_lr": ["final_lr_ratio"],
+ }
+
+
+ class HpoptSubcommand(Subcommand):
+     COMMAND = "hpopt"
+     HELP = "Perform hyperparameter optimization on the given task."
+
+     @classmethod
+     def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
+         parser = add_common_args(parser)
+         parser = add_train_args(parser)
+         return add_hpopt_args(parser)
+
+     @classmethod
+     def func(cls, args: Namespace):
+         args = process_common_args(args)
+         args = process_train_args(args)
+         args = process_hpopt_args(args)
+         validate_common_args(args)
+         validate_train_args(args)
+         main(args)
+
+
+ def add_hpopt_args(parser: ArgumentParser) -> ArgumentParser:
+     hpopt_args = parser.add_argument_group("Chemprop hyperparameter optimization arguments")
+
+     hpopt_args.add_argument(
+         "--search-parameter-keywords",
+         type=str,
+         nargs="+",
+         default=["basic"],
+         help=f"""The model parameters over which to search for an optimal hyperparameter configuration. Some options are bundles of parameters or otherwise special parameter operations. Special keywords include:
+ - ``basic``: Default set of hyperparameters for search (depth, ffn_num_layers, dropout, message_hidden_dim, and ffn_hidden_dim)
+ - ``learning_rate``: Search for max_lr, init_lr_ratio, final_lr_ratio, and warmup_epochs. The searches for the init_lr and final_lr values are defined as fractions of the max_lr value. The search for warmup_epochs is as a fraction of the total epochs used.
+ - ``all``: Include search for all 13 individual keyword options (including activation, aggregation, aggregation_norm, and batch_size, which aren't included in the other two keywords).
+ Individual supported parameters:
+ {list(DEFAULT_SEARCH_SPACE.keys())}
+ """,
+     )
+
+     hpopt_args.add_argument(
+         "--hpopt-save-dir",
+         type=Path,
+         help="Directory to save the hyperparameter optimization results",
+     )
+
+     raytune_args = parser.add_argument_group("Ray Tune arguments")
+
+     raytune_args.add_argument(
+         "--raytune-num-samples",
+         type=int,
+         default=10,
+         help="Passed directly to Ray Tune ``TuneConfig`` to control number of trials to run",
+     )
+
+     raytune_args.add_argument(
+         "--raytune-search-algorithm",
+         choices=["random", "hyperopt", "optuna"],
+         default="hyperopt",
+         help="Passed to Ray Tune ``TuneConfig`` to control search algorithm",
+     )
+
+     raytune_args.add_argument(
+         "--raytune-trial-scheduler",
+         choices=["FIFO", "AsyncHyperBand"],
+         default="FIFO",
+         help="Passed to Ray Tune ``TuneConfig`` to control trial scheduler",
+     )
+
+     raytune_args.add_argument(
+         "--raytune-num-workers",
+         type=int,
+         default=1,
+         help="Passed directly to Ray Tune ``ScalingConfig`` to control number of workers to use",
+     )
+
+     raytune_args.add_argument(
+         "--raytune-use-gpu",
+         action="store_true",
+         help="Passed directly to Ray Tune ``ScalingConfig`` to control whether to use GPUs",
+     )
+
+     raytune_args.add_argument(
+         "--raytune-num-checkpoints-to-keep",
+         type=int,
+         default=1,
+         help="Passed directly to Ray Tune ``CheckpointConfig`` to control number of checkpoints to keep",
+     )
+
+     raytune_args.add_argument(
+         "--raytune-grace-period",
+         type=int,
+         default=10,
+         help="Passed directly to Ray Tune ``ASHAScheduler`` to control grace period",
+     )
+
+     raytune_args.add_argument(
+         "--raytune-reduction-factor",
+         type=int,
+         default=2,
+         help="Passed directly to Ray Tune ``ASHAScheduler`` to control reduction factor",
+     )
+
+     raytune_args.add_argument(
+         "--raytune-temp-dir", help="Passed directly to Ray Tune init to control temporary directory"
+     )
+
+     raytune_args.add_argument(
+         "--raytune-num-cpus",
+         type=int,
+         help="Passed directly to Ray Tune init to control number of CPUs to use",
+     )
+
+     raytune_args.add_argument(
+         "--raytune-num-gpus",
+         type=int,
+         help="Passed directly to Ray Tune init to control number of GPUs to use",
+     )
+
+     raytune_args.add_argument(
+         "--raytune-max-concurrent-trials",
+         type=int,
+         help="Passed directly to Ray Tune TuneConfig to control maximum concurrent trials",
+     )
+
+     hyperopt_args = parser.add_argument_group("Hyperopt arguments")
+
+     hyperopt_args.add_argument(
+         "--hyperopt-n-initial-points",
+         type=int,
+         help="Passed directly to ``HyperOptSearch`` to control number of initial points to sample",
+     )
+
+     hyperopt_args.add_argument(
+         "--hyperopt-random-state-seed",
+         type=int,
+         default=None,
+         help="Passed directly to ``HyperOptSearch`` to control random state seed",
+     )
+
+     return parser
+
+
+ def process_hpopt_args(args: Namespace) -> Namespace:
+     if args.hpopt_save_dir is None:
+         args.hpopt_save_dir = Path(f"chemprop_hpopt/{args.data_path.stem}")
+
+     args.hpopt_save_dir.mkdir(exist_ok=True, parents=True)
+
+     search_parameters = set()
+
+     available_search_parameters = list(SEARCH_SPACE.keys()) + list(SEARCH_PARAM_KEYWORDS_MAP.keys())
+
+     for keyword in args.search_parameter_keywords:
+         if keyword not in available_search_parameters:
+             raise ValueError(
+                 f"Search parameter keyword: {keyword} not in available options: {available_search_parameters}."
+             )
+
+         search_parameters.update(
+             SEARCH_PARAM_KEYWORDS_MAP[keyword]
+             if keyword in SEARCH_PARAM_KEYWORDS_MAP
+             else [keyword]
+         )
+
+     args.search_parameter_keywords = list(search_parameters)
+
+     if not args.hyperopt_n_initial_points:
+         args.hyperopt_n_initial_points = args.raytune_num_samples // 2
+
+     return args
+
+
+ def build_search_space(search_parameters: list[str], train_epochs: int) -> dict:
+     if "warmup_epochs" in search_parameters and SEARCH_SPACE.get("warmup_epochs", None) is None:
+         assert (
+             train_epochs >= 6
+         ), "Training epochs must be at least 6 to perform hyperparameter optimization for warmup_epochs."
+         SEARCH_SPACE["warmup_epochs"] = tune.qrandint(lower=1, upper=train_epochs // 2, q=1)
+
+     return {param: SEARCH_SPACE[param] for param in search_parameters}
+
+
+ def update_args_with_config(args: Namespace, config: dict) -> Namespace:
+     args = deepcopy(args)
+
+     for key, value in config.items():
+         match key:
+             case "final_lr_ratio":
+                 setattr(args, "final_lr", value * config.get("max_lr", args.max_lr))
+
+             case "init_lr_ratio":
+                 setattr(args, "init_lr", value * config.get("max_lr", args.max_lr))
+
+             case _:
+                 assert key in args, f"Key: {key} not found in args."
+                 setattr(args, key, value)
+
+     return args
+
+
+ def train_model(config, args, train_dset, val_dset, logger, output_transform, input_transforms):
+     args = update_args_with_config(args, config)
+
+     train_loader = build_dataloader(
+         train_dset, args.batch_size, args.num_workers, seed=args.data_seed
+     )
+     val_loader = build_dataloader(val_dset, args.batch_size, args.num_workers, shuffle=False)
+
+     seed = args.pytorch_seed if args.pytorch_seed is not None else torch.seed()
+
+     torch.manual_seed(seed)
+
+     model = build_model(args, train_loader.dataset, output_transform, input_transforms)
+     logger.info(model)
+
+     if args.tracking_metric == "val_loss":
+         T_tracking_metric = model.criterion.__class__
+     else:
+         T_tracking_metric = MetricRegistry[args.tracking_metric]
+         args.tracking_metric = "val/" + args.tracking_metric
+
+     monitor_mode = "max" if T_tracking_metric.higher_is_better else "min"
+     logger.debug(f"Evaluation metric: '{T_tracking_metric.alias}', mode: '{monitor_mode}'")
+
+     patience = args.patience if args.patience is not None else args.epochs
+     early_stopping = EarlyStopping(args.tracking_metric, patience=patience, mode=monitor_mode)
+
+     trainer = pl.Trainer(
+         accelerator=args.accelerator,
+         devices=args.devices,
+         max_epochs=args.epochs,
+         gradient_clip_val=args.grad_clip,
+         strategy=RayDDPStrategy(),
+         callbacks=[RayTrainReportCallback(), early_stopping],
+         plugins=[RayLightningEnvironment()],
+         deterministic=args.pytorch_seed is not None,
+     )
+     trainer = prepare_trainer(trainer)
+     trainer.fit(model, train_loader, val_loader)
+
+
+ def tune_model(
+     args, train_dset, val_dset, logger, monitor_mode, output_transform, input_transforms
+ ):
+     match args.raytune_trial_scheduler:
+         case "FIFO":
+             scheduler = FIFOScheduler()
+         case "AsyncHyperBand":
+             scheduler = ASHAScheduler(
+                 max_t=args.epochs,
+                 grace_period=min(args.raytune_grace_period, args.epochs),
+                 reduction_factor=args.raytune_reduction_factor,
+             )
+         case _:
+             raise ValueError(f"Invalid trial scheduler! got: {args.raytune_trial_scheduler}.")
+
+     resources_per_worker = {}
+     if args.raytune_num_cpus and args.raytune_max_concurrent_trials:
+         resources_per_worker["CPU"] = args.raytune_num_cpus / args.raytune_max_concurrent_trials
+     if args.raytune_num_gpus and args.raytune_max_concurrent_trials:
+         resources_per_worker["GPU"] = args.raytune_num_gpus / args.raytune_max_concurrent_trials
+     if not resources_per_worker:
+         resources_per_worker = None
+
+     if args.raytune_num_gpus:
+         use_gpu = True
+     else:
+         use_gpu = args.raytune_use_gpu
+
+     scaling_config = ScalingConfig(
+         num_workers=args.raytune_num_workers,
+         use_gpu=use_gpu,
+         resources_per_worker=resources_per_worker,
+         trainer_resources={"CPU": 0},
+     )
+
+     checkpoint_config = CheckpointConfig(
+         num_to_keep=args.raytune_num_checkpoints_to_keep,
+         checkpoint_score_attribute=args.tracking_metric,
+         checkpoint_score_order=monitor_mode,
+     )
+
+     run_config = RunConfig(
+         checkpoint_config=checkpoint_config,
+         storage_path=args.hpopt_save_dir.absolute() / "ray_results",
+     )
+
+     ray_trainer = TorchTrainer(
+         lambda config: train_model(
+             config, args, train_dset, val_dset, logger, output_transform, input_transforms
+         ),
+         scaling_config=scaling_config,
+         run_config=run_config,
+     )
+
+     match args.raytune_search_algorithm:
+         case "random":
+             search_alg = None
+         case "hyperopt":
+             if NO_HYPEROPT:
+                 raise ImportError(
+                     "HyperOptSearch requires hyperopt to be installed. Use 'pip install -U hyperopt' to install or use 'pip install -e .[hpopt]' in chemprop folder if you installed from source to install all hpopt relevant packages."
+                 )
+
+             search_alg = HyperOptSearch(
+                 n_initial_points=args.hyperopt_n_initial_points,
+                 random_state_seed=args.hyperopt_random_state_seed,
+             )
+         case "optuna":
+             if NO_OPTUNA:
+                 raise ImportError(
+                     "OptunaSearch requires optuna to be installed. Use 'pip install -U optuna' to install or use 'pip install -e .[hpopt]' in chemprop folder if you installed from source to install all hpopt relevant packages."
+                 )
+
+             search_alg = OptunaSearch()
+
+     tune_config = tune.TuneConfig(
+         metric=args.tracking_metric,
+         mode=monitor_mode,
+         num_samples=args.raytune_num_samples,
+         scheduler=scheduler,
+         search_alg=search_alg,
+         trial_dirname_creator=lambda trial: str(trial.trial_id),
+     )
+
+     tuner = tune.Tuner(
+         ray_trainer,
+         param_space={
+             "train_loop_config": build_search_space(args.search_parameter_keywords, args.epochs)
+         },
+         tune_config=tune_config,
+     )
+
+     return tuner.fit()
+
+
+ def main(args: Namespace):
+     if NO_RAY:
+         raise ImportError(
+             "Ray Tune requires ray to be installed. If you installed Chemprop from PyPI, run 'pip install -U ray[tune]' to install ray. If you installed from source, use 'pip install -e .[hpopt]' in Chemprop folder to install all hpopt relevant packages."
+         )
+
+     if not ray.is_initialized():
+         try:
+             ray.init(
+                 _temp_dir=args.raytune_temp_dir,
+                 num_cpus=args.raytune_num_cpus,
+                 num_gpus=args.raytune_num_gpus,
+             )
+         except OSError as e:
+             if "AF_UNIX path length cannot exceed 107 bytes" in str(e):
+                 raise OSError(
+                     f"Ray Tune fails due to: {e}. This can sometimes be solved by providing a temporary directory, num_cpus, and num_gpus to Ray Tune via the CLI: --raytune-temp-dir <absolute_path> --raytune-num-cpus <int> --raytune-num-gpus <int>."
+                 )
+             else:
+                 raise e
+     else:
+         logger.info("Ray is already initialized.")
+
+     format_kwargs = dict(
+         no_header_row=args.no_header_row,
+         smiles_cols=args.smiles_columns,
+         rxn_cols=args.reaction_columns,
+         target_cols=args.target_columns,
+         ignore_cols=args.ignore_columns,
+         splits_col=args.splits_column,
+         weight_col=args.weight_column,
+         bounded=args.loss_function is not None and "bounded" in args.loss_function,
+     )
+
+     featurization_kwargs = dict(
+         molecule_featurizers=args.molecule_featurizers,
+         keep_h=args.keep_h,
+         add_h=args.add_h,
+         ignore_chirality=args.ignore_chirality,
+     )
+
+     train_data, val_data, test_data = build_splits(args, format_kwargs, featurization_kwargs)
+     train_dset, val_dset, test_dset = build_datasets(args, train_data[0], val_data[0], test_data[0])
+
+     input_transforms = normalize_inputs(train_dset, val_dset, args)
+
+     if "regression" in args.task_type:
+         output_scaler = train_dset.normalize_targets()
+         val_dset.normalize_targets(output_scaler)
+         logger.info(f"Train data: mean = {output_scaler.mean_} | std = {output_scaler.scale_}")
+         output_transform = UnscaleTransform.from_standard_scaler(output_scaler)
+     else:
+         output_transform = None
+
+     train_loader = build_dataloader(
+         train_dset, args.batch_size, args.num_workers, seed=args.data_seed
+     )
+
+     model = build_model(args, train_loader.dataset, output_transform, input_transforms)
+     monitor_mode = "max" if model.metrics[0].higher_is_better else "min"
+
+     results = tune_model(
+         args, train_dset, val_dset, logger, monitor_mode, output_transform, input_transforms
+     )
+
+     best_result = results.get_best_result()
+     best_config = best_result.config["train_loop_config"]
+     best_checkpoint_path = Path(best_result.checkpoint.path) / "checkpoint.ckpt"
+
+     best_config_save_path = args.hpopt_save_dir / "best_config.toml"
+     best_checkpoint_save_path = args.hpopt_save_dir / "best_checkpoint.ckpt"
+     all_progress_save_path = args.hpopt_save_dir / "all_progress.csv"
+
+     logger.info(f"Best hyperparameters saved to: '{best_config_save_path}'")
+
+     args = update_args_with_config(args, best_config)
+
+     args = TrainSubcommand.parser.parse_known_args(namespace=args)[0]
+     save_config(TrainSubcommand.parser, args, best_config_save_path)
+
+     logger.info(
+         f"Best hyperparameter configuration checkpoint saved to '{best_checkpoint_save_path}'"
+     )
+
+     shutil.copyfile(best_checkpoint_path, best_checkpoint_save_path)
+
+     logger.info(f"Hyperparameter optimization results saved to '{all_progress_save_path}'")
+
+     result_df = results.get_dataframe()
+
+     result_df.to_csv(all_progress_save_path, index=False)
+
+     ray.shutdown()
+
+
+ if __name__ == "__main__":
+     parser = ArgumentParser()
+     parser = HpoptSubcommand.add_args(parser)
+
+     logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
+     args = parser.parse_args()
+     HpoptSubcommand.func(args)
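
Note: a minimal check (illustrative values only, not from the commit) of the learning-rate ratio handling in update_args_with_config above:

    from argparse import Namespace

    args = Namespace(max_lr=1e-3, init_lr=1e-4, final_lr=1e-4)
    new = update_args_with_config(args, {"max_lr": 1e-2, "init_lr_ratio": 0.1})
    # "max_lr" hits the default case and is set directly; "init_lr_ratio" is
    # resolved against the max_lr in the same config dict:
    assert new.max_lr == 1e-2 and new.init_lr == 0.1 * 1e-2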
chemprop-updated/chemprop/cli/main.py ADDED
@@ -0,0 +1,85 @@
+ import logging
+ from pathlib import Path
+ import sys
+
+ from configargparse import ArgumentParser
+
+ from chemprop.cli.conf import LOG_DIR, LOG_LEVELS, NOW
+ from chemprop.cli.convert import ConvertSubcommand
+ from chemprop.cli.fingerprint import FingerprintSubcommand
+ from chemprop.cli.hpopt import HpoptSubcommand
+ from chemprop.cli.predict import PredictSubcommand
+ from chemprop.cli.train import TrainSubcommand
+ from chemprop.cli.utils import pop_attr
+
+ logger = logging.getLogger(__name__)
+
+ SUBCOMMANDS = [
+     TrainSubcommand,
+     PredictSubcommand,
+     ConvertSubcommand,
+     FingerprintSubcommand,
+     HpoptSubcommand,
+ ]
+
+
+ def construct_parser():
+     parser = ArgumentParser()
+     subparsers = parser.add_subparsers(title="mode", dest="mode", required=True)
+
+     parent = ArgumentParser(add_help=False)
+     parent.add_argument(
+         "--logfile",
+         "--log",
+         nargs="?",
+         const="default",
+         help=f"Path to which the log file should be written (specifying just the flag alone will automatically log to a file ``{LOG_DIR}/MODE/TIMESTAMP.log``, where 'MODE' is the CLI mode chosen, e.g., ``{LOG_DIR}/MODE/{NOW}.log``)",
+     )
+     parent.add_argument("-v", action="store_true", help="Increase verbosity level to DEBUG")
+     parent.add_argument(
+         "-q",
+         action="count",
+         default=0,
+         help="Decrease verbosity level to WARNING or ERROR if specified twice",
+     )
+
+     parents = [parent]
+     for subcommand in SUBCOMMANDS:
+         subcommand.add(subparsers, parents)
+
+     return parser
+
+
+ def main():
+     parser = construct_parser()
+     args = parser.parse_args()
+     logfile, v_flag, q_count, mode, func = (
+         pop_attr(args, attr) for attr in ["logfile", "v", "q", "mode", "func"]
+     )
+
+     if v_flag and q_count:
+         parser.error("The -v and -q options cannot be used together.")
+
+     match logfile:
+         case None:
+             handler = logging.StreamHandler(sys.stderr)
+         case "default":
+             (LOG_DIR / mode).mkdir(parents=True, exist_ok=True)
+             handler = logging.FileHandler(str(LOG_DIR / mode / f"{NOW}.log"))
+         case _:
+             Path(logfile).parent.mkdir(parents=True, exist_ok=True)
+             handler = logging.FileHandler(logfile)
+
+     verbosity = q_count * -1 if q_count else (1 if v_flag else 0)
+     logging_level = LOG_LEVELS.get(verbosity, logging.ERROR)
+     logging.basicConfig(
+         handlers=[handler],
+         format="%(asctime)s - %(levelname)s:%(name)s - %(message)s",
+         level=logging_level,
+         datefmt="%Y-%m-%dT%H:%M:%S",
+         force=True,
+     )
+
+     logger.info(f"Running in mode '{mode}' with args: {vars(args)}")
+
+     func(args)
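
Note: the verbosity handling in main above maps flag combinations through LOG_LEVELS from chemprop.cli.conf; in sketch form:

    verbosity = q_count * -1 if q_count else (1 if v_flag else 0)
    # (no flag) -> verbosity  0 -> logging.INFO
    # -v        -> verbosity  1 -> logging.DEBUG
    # -q        -> verbosity -1 -> logging.WARNING
    # -qq       -> verbosity -2 -> logging.ERROR
    # anything lower (e.g. -qqq) falls back to logging.ERROR via LOG_LEVELS.get(...)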
chemprop-updated/chemprop/cli/predict.py ADDED
@@ -0,0 +1,447 @@
+ from argparse import ArgumentError, ArgumentParser, Namespace
+ import logging
+ from pathlib import Path
+ import sys
+ from typing import Iterator
+
+ from lightning import pytorch as pl
+ import numpy as np
+ import pandas as pd
+ import torch
+
+ from chemprop import data
+ from chemprop.cli.common import (
+     add_common_args,
+     find_models,
+     process_common_args,
+     validate_common_args,
+ )
+ from chemprop.cli.utils import LookupAction, Subcommand, build_data_from_files, make_dataset
+ from chemprop.models.utils import load_model, load_output_columns
+ from chemprop.nn.metrics import LossFunctionRegistry
+ from chemprop.nn.predictors import EvidentialFFN, MulticlassClassificationFFN, MveFFN
+ from chemprop.uncertainty import (
+     MVEWeightingCalibrator,
+     NoUncertaintyEstimator,
+     RegressionCalibrator,
+     RegressionEvaluator,
+     UncertaintyCalibratorRegistry,
+     UncertaintyEstimatorRegistry,
+     UncertaintyEvaluatorRegistry,
+ )
+ from chemprop.utils import Factory
+
+ logger = logging.getLogger(__name__)
+
+
+ class PredictSubcommand(Subcommand):
+     COMMAND = "predict"
+     HELP = "use a pretrained chemprop model for prediction"
+
+     @classmethod
+     def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
+         parser = add_common_args(parser)
+         return add_predict_args(parser)
+
+     @classmethod
+     def func(cls, args: Namespace):
+         args = process_common_args(args)
+         validate_common_args(args)
+         args = process_predict_args(args)
+         main(args)
+
+
+ def add_predict_args(parser: ArgumentParser) -> ArgumentParser:
+     parser.add_argument(
+         "-i",
+         "--test-path",
+         required=True,
+         type=Path,
+         help="Path to an input CSV file containing SMILES",
+     )
+     parser.add_argument(
+         "-o",
+         "--output",
+         "--preds-path",
+         type=Path,
+         help="Specify path to which predictions will be saved. If the file extension is .pkl, it will be saved as a pickle file. Otherwise, chemprop will save predictions as a CSV. If multiple models are used to make predictions, the average predictions will be saved in the file, and another file ending in '_individual' with the same file extension will save the predictions for each individual model, with the column names being the target names appended with the model index (e.g., '_model_<index>').",
+     )
+     parser.add_argument(
+         "--drop-extra-columns",
+         action="store_true",
+         help="Whether to drop all columns from the test data file besides the SMILES columns and the new prediction columns",
+     )
+     parser.add_argument(
+         "--model-paths",
+         "--model-path",
+         required=True,
+         type=Path,
+         nargs="+",
+         help="Location of checkpoint(s) or model file(s) to use for prediction. It can be a path to either a single pretrained model checkpoint (.ckpt) or single pretrained model file (.pt), a directory that contains these files, or a list of path(s) and directory(s). If a directory, will recursively search and predict on all found (.pt) models.",
+     )
+
+     unc_args = parser.add_argument_group("Uncertainty and calibration args")
+     unc_args.add_argument(
+         "--cal-path", type=Path, help="Path to data file to be used for uncertainty calibration."
+     )
+     unc_args.add_argument(
+         "--uncertainty-method",
+         default="none",
+         action=LookupAction(UncertaintyEstimatorRegistry),
+         help="The method of calculating uncertainty.",
+     )
+     unc_args.add_argument(
+         "--calibration-method",
+         action=LookupAction(UncertaintyCalibratorRegistry),
+         help="The method used for calibrating the uncertainty calculated with uncertainty method.",
+     )
+     unc_args.add_argument(
+         "--evaluation-methods",
+         "--evaluation-method",
+         nargs="+",
+         action=LookupAction(UncertaintyEvaluatorRegistry),
+         help="The methods used for evaluating the uncertainty performance if the test data provided includes targets. Available methods are [nll, miscalibration_area, ence, spearman] or any available classification or multiclass metric.",
+     )
+     # unc_args.add_argument(
+     #     "--evaluation-scores-path", help="Location to save the results of uncertainty evaluations."
+     # )
+     unc_args.add_argument(
+         "--uncertainty-dropout-p",
+         type=float,
+         default=0.1,
+         help="The probability to use for Monte Carlo dropout uncertainty estimation.",
+     )
+     unc_args.add_argument(
+         "--dropout-sampling-size",
+         type=int,
+         default=10,
+         help="The number of samples to use for Monte Carlo dropout uncertainty estimation. Distinct from the dropout used during training.",
+     )
+     unc_args.add_argument(
+         "--calibration-interval-percentile",
+         type=float,
+         default=95,
+         help="Sets the percentile used in the calibration methods. Must be in the range (1, 100).",
+     )
+     unc_args.add_argument(
+         "--conformal-alpha",
+         type=float,
+         default=0.1,
+         help="Target error rate for conformal prediction. Must be in the range (0, 1).",
+     )
+     # TODO: Decide if we want to implement this in v2.1.x
+     # unc_args.add_argument(
+     #     "--regression-calibrator-metric",
+     #     choices=["stdev", "interval"],
+     #     help="Regression calibrators can output either a stdev or an interval.",
+     # )
+     unc_args.add_argument(
+         "--cal-descriptors-path",
+         nargs="+",
+         action="append",
+         help="Path to extra descriptors to concatenate to learned representation in calibration dataset.",
+     )
+     # TODO: Add in v2.1.x
+     # unc_args.add_argument(
+     #     "--calibration-phase-features-path",
+     #     help=" ",
+     # )
+     unc_args.add_argument(
+         "--cal-atom-features-path",
+         nargs="+",
+         action="append",
+         help="Path to the extra atom features in calibration dataset.",
+     )
+     unc_args.add_argument(
+         "--cal-atom-descriptors-path",
+         nargs="+",
+         action="append",
+         help="Path to the extra atom descriptors in calibration dataset.",
+     )
+     unc_args.add_argument(
+         "--cal-bond-features-path",
+         nargs="+",
+         action="append",
+         help="Path to the extra bond descriptors in calibration dataset.",
+     )
+
+     return parser
+
+
+ def process_predict_args(args: Namespace) -> Namespace:
+     if args.test_path.suffix not in [".csv"]:
+         raise ArgumentError(
+             argument=None, message=f"Input data must be a CSV file. Got {args.test_path}"
+         )
+     if args.output is None:
+         args.output = args.test_path.parent / (args.test_path.stem + "_preds.csv")
+     if args.output.suffix not in [".csv", ".pkl"]:
+         raise ArgumentError(
+             argument=None, message=f"Output must be a CSV or Pickle file. Got {args.output}"
+         )
+     return args
+
+
+ def prepare_data_loader(
+     args: Namespace, multicomponent: bool, is_calibration: bool, format_kwargs: dict
+ ):
+     data_path = args.cal_path if is_calibration else args.test_path
+     descriptors_path = args.cal_descriptors_path if is_calibration else args.descriptors_path
+     atom_feats_path = args.cal_atom_features_path if is_calibration else args.atom_features_path
+     bond_feats_path = args.cal_bond_features_path if is_calibration else args.bond_features_path
+     atom_descs_path = (
+         args.cal_atom_descriptors_path if is_calibration else args.atom_descriptors_path
+     )
+
+     featurization_kwargs = dict(
+         molecule_featurizers=args.molecule_featurizers,
+         keep_h=args.keep_h,
+         add_h=args.add_h,
+         ignore_chirality=args.ignore_chirality,
+     )
+
+     datas = build_data_from_files(
+         data_path,
+         **format_kwargs,
+         p_descriptors=descriptors_path,
+         p_atom_feats=atom_feats_path,
+         p_bond_feats=bond_feats_path,
+         p_atom_descs=atom_descs_path,
+         **featurization_kwargs,
+     )
+
+     dsets = [make_dataset(d, args.rxn_mode, args.multi_hot_atom_featurizer_mode) for d in datas]
+     dset = data.MulticomponentDataset(dsets) if multicomponent else dsets[0]
+
+     return data.build_dataloader(dset, args.batch_size, args.num_workers, shuffle=False)
+
+
+ def make_prediction_for_models(
+     args: Namespace, model_paths: Iterator[Path], multicomponent: bool, output_path: Path
+ ):
+     model = load_model(model_paths[0], multicomponent)
+     output_columns = load_output_columns(model_paths[0])
+     bounded = any(
+         isinstance(model.criterion, LossFunctionRegistry[loss_function])
+         for loss_function in LossFunctionRegistry.keys()
+         if "bounded" in loss_function
+     )
+     format_kwargs = dict(
+         no_header_row=args.no_header_row,
+         smiles_cols=args.smiles_columns,
+         rxn_cols=args.reaction_columns,
+         ignore_cols=None,
+         splits_col=None,
+         weight_col=None,
+         bounded=bounded,
+     )
+     format_kwargs["target_cols"] = output_columns if args.evaluation_methods is not None else []
+     test_loader = prepare_data_loader(args, multicomponent, False, format_kwargs)
+     logger.info(f"test size: {len(test_loader.dataset)}")
+     if args.cal_path is not None:
+         format_kwargs["target_cols"] = output_columns
+         cal_loader = prepare_data_loader(args, multicomponent, True, format_kwargs)
+         logger.info(f"calibration size: {len(cal_loader.dataset)}")
+
+     uncertainty_estimator = Factory.build(
+         UncertaintyEstimatorRegistry[args.uncertainty_method],
+         ensemble_size=args.dropout_sampling_size,
+         dropout=args.uncertainty_dropout_p,
+     )
+
+     models = [load_model(model_path, multicomponent) for model_path in model_paths]
+     trainer = pl.Trainer(
+         logger=False, enable_progress_bar=True, accelerator=args.accelerator, devices=args.devices
+     )
+     test_individual_preds, test_individual_uncs = uncertainty_estimator(
+         test_loader, models, trainer
+     )
+     test_preds = torch.mean(test_individual_preds, dim=0)
+     if not isinstance(uncertainty_estimator, NoUncertaintyEstimator):
+         test_uncs = torch.mean(test_individual_uncs, dim=0)
+     else:
+         test_uncs = None
+
+     if args.calibration_method is not None:
+         uncertainty_calibrator = Factory.build(
+             UncertaintyCalibratorRegistry[args.calibration_method],
+             p=args.calibration_interval_percentile / 100,
+             alpha=args.conformal_alpha,
+         )
+         cal_targets = cal_loader.dataset.Y
+         cal_mask = torch.from_numpy(np.isfinite(cal_targets))
+         cal_targets = np.nan_to_num(cal_targets, nan=0.0)
+         cal_targets = torch.from_numpy(cal_targets)
+         cal_individual_preds, cal_individual_uncs = uncertainty_estimator(
276
+ cal_loader, models, trainer
277
+ )
278
+ cal_preds = torch.mean(cal_individual_preds, dim=0)
279
+ cal_uncs = torch.mean(cal_individual_uncs, dim=0)
280
+ if isinstance(uncertainty_calibrator, MVEWeightingCalibrator):
281
+ uncertainty_calibrator.fit(cal_preds, cal_individual_uncs, cal_targets, cal_mask)
282
+ test_uncs = uncertainty_calibrator.apply(cal_individual_uncs)
283
+ else:
284
+ if isinstance(uncertainty_calibrator, RegressionCalibrator):
285
+ uncertainty_calibrator.fit(cal_preds, cal_uncs, cal_targets, cal_mask)
286
+ else:
287
+ uncertainty_calibrator.fit(cal_uncs, cal_targets, cal_mask)
288
+ test_uncs = uncertainty_calibrator.apply(test_uncs)
289
+ for i in range(test_individual_uncs.shape[0]):
290
+ test_individual_uncs[i] = uncertainty_calibrator.apply(test_individual_uncs[i])
291
+
292
+ if args.evaluation_methods is not None:
293
+ uncertainty_evaluators = [
294
+ Factory.build(UncertaintyEvaluatorRegistry[method])
295
+ for method in args.evaluation_methods
296
+ ]
297
+ logger.info("Uncertainty evaluation metric:")
298
+ for evaluator in uncertainty_evaluators:
299
+ test_targets = test_loader.dataset.Y
300
+ test_mask = torch.from_numpy(np.isfinite(test_targets))
301
+ test_targets = np.nan_to_num(test_targets, nan=0.0)
302
+ test_targets = torch.from_numpy(test_targets)
303
+ if isinstance(evaluator, RegressionEvaluator):
304
+ metric_value = evaluator.evaluate(test_preds, test_uncs, test_targets, test_mask)
305
+ else:
306
+ metric_value = evaluator.evaluate(test_uncs, test_targets, test_mask)
307
+ logger.info(f"{evaluator.alias}: {metric_value.tolist()}")
308
+
309
+ if args.uncertainty_method == "none" and (
310
+ isinstance(model.predictor, (MveFFN, EvidentialFFN))
311
+ ):
312
+ test_preds = test_preds[..., 0]
313
+ test_individual_preds = test_individual_preds[..., 0]
314
+
315
+ if output_columns is None:
316
+ output_columns = [
317
+ f"pred_{i}" for i in range(test_preds.shape[1])
318
+ ] # TODO: need to improve this for cases like multi-task MVE and multi-task multiclass
319
+
320
+ save_predictions(args, model, output_columns, test_preds, test_uncs, output_path)
321
+
322
+ if len(model_paths) > 1:
323
+ save_individual_predictions(
324
+ args,
325
+ model,
326
+ model_paths,
327
+ output_columns,
328
+ test_individual_preds,
329
+ test_individual_uncs,
330
+ output_path,
331
+ )
332
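The estimator used above returns stacked per-model predictions of shape ``(n_models, n_datapoints, n_tasks)``, and the ensemble prediction is the mean over the model axis. A small numeric sketch with torch:

```python
import torch

# (n_models, n_datapoints, n_tasks); values illustrative
individual_preds = torch.tensor([[[1.0], [2.0]],
                                 [[1.2], [1.8]]])
print(torch.mean(individual_preds, dim=0))  # tensor([[1.1000], [1.9000]])
```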
+
333
+
334
+ def save_predictions(args, model, output_columns, test_preds, test_uncs, output_path):
335
+ unc_columns = [f"{col}_unc" for col in output_columns]
336
+
337
+ if isinstance(model.predictor, MulticlassClassificationFFN):
338
+ output_columns = output_columns + [f"{col}_prob" for col in output_columns]
339
+ predicted_class_labels = test_preds.argmax(axis=-1)
340
+ formatted_probability_strings = np.apply_along_axis(
341
+ lambda x: ",".join(map(str, x)), 2, test_preds.numpy()
342
+ )
343
+ test_preds = np.concatenate(
344
+ (predicted_class_labels, formatted_probability_strings), axis=-1
345
+ )
346
+
347
+ df_test = pd.read_csv(
348
+ args.test_path, header=None if args.no_header_row else "infer", index_col=False
349
+ )
350
+ df_test[output_columns] = test_preds
351
+
352
+ if args.uncertainty_method not in ["none", "classification"]:
353
+ df_test[unc_columns] = np.round(test_uncs, 6)
354
+
355
+ if output_path.suffix == ".pkl":
356
+ df_test = df_test.reset_index(drop=True)
357
+ df_test.to_pickle(output_path)
358
+ else:
359
+ df_test.to_csv(output_path, index=False)
360
+ logger.info(f"Predictions saved to '{output_path}'")
361
+
362
+
363
+ def save_individual_predictions(
364
+ args,
365
+ model,
366
+ model_paths,
367
+ output_columns,
368
+ test_individual_preds,
369
+ test_individual_uncs,
370
+ output_path,
371
+ ):
372
+ unc_columns = [
373
+ f"{col}_unc_model_{i}" for i in range(len(model_paths)) for col in output_columns
374
+ ]
375
+
376
+ if isinstance(model.predictor, MulticlassClassificationFFN):
377
+ output_columns = [
378
+ item
379
+ for i in range(len(model_paths))
380
+ for col in output_columns
381
+ for item in (f"{col}_model_{i}", f"{col}_prob_model_{i}")
382
+ ]
383
+
384
+ predicted_class_labels = test_individual_preds.argmax(axis=-1)
385
+ formatted_probability_strings = np.apply_along_axis(
386
+ lambda x: ",".join(map(str, x)), 3, test_individual_preds.numpy()
387
+ )
388
+ test_individual_preds = np.concatenate(
389
+ (predicted_class_labels, formatted_probability_strings), axis=-1
390
+ )
391
+ else:
392
+ output_columns = [
393
+ f"{col}_model_{i}" for i in range(len(model_paths)) for col in output_columns
394
+ ]
395
+
396
+ m, n, t = test_individual_preds.shape
397
+ test_individual_preds = np.transpose(test_individual_preds, (1, 0, 2)).reshape(n, m * t)
398
+ df_test = pd.read_csv(
399
+ args.test_path, header=None if args.no_header_row else "infer", index_col=False
400
+ )
401
+ df_test[output_columns] = test_individual_preds
402
+
403
+ if args.uncertainty_method not in ["none", "classification", "ensemble"]:
404
+ m, n, t = test_individual_uncs.shape
405
+ test_individual_uncs = np.transpose(test_individual_uncs, (1, 0, 2)).reshape(n, m * t)
406
+ df_test[unc_columns] = np.round(test_individual_uncs, 6)
407
+
408
+ output_path = output_path.parent / Path(
409
+ str(args.output.stem) + "_individual" + str(output_path.suffix)
410
+ )
411
+ if output_path.suffix == ".pkl":
412
+ df_test = df_test.reset_index(drop=True)
413
+ df_test.to_pickle(output_path)
414
+ else:
415
+ df_test.to_csv(output_path, index=False)
416
+ logger.info(f"Individual predictions saved to '{output_path}'")
417
+ for i, model_path in enumerate(model_paths):
418
+ logger.info(
419
+ f"Results from model path {model_path} are saved under the column name ending with 'model_{i}'"
420
+ )
421
+
422
+
423
+ def main(args):
424
+ match (args.smiles_columns, args.reaction_columns):
425
+ case [None, None]:
426
+ n_components = 1
427
+ case [_, None]:
428
+ n_components = len(args.smiles_columns)
429
+ case [None, _]:
430
+ n_components = len(args.reaction_columns)
431
+ case _:
432
+ n_components = len(args.smiles_columns) + len(args.reaction_columns)
433
+
434
+ multicomponent = n_components > 1
435
+
436
+ model_paths = find_models(args.model_paths)
437
+
438
+ make_prediction_for_models(args, model_paths, multicomponent, output_path=args.output)
439
+
440
+
441
+ if __name__ == "__main__":
442
+ parser = ArgumentParser()
443
+ parser = PredictSubcommand.add_args(parser)
444
+
445
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
446
+ args = parser.parse_args()
447
+ args = PredictSubcommand.func(args)
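The same subcommand can also be driven programmatically, mirroring the ``__main__`` block above. A minimal sketch, assuming the package is importable as ``chemprop`` and that a ``--test-path`` flag is defined earlier in this file (outside this hunk):

```python
from configargparse import ArgumentParser

from chemprop.cli.predict import PredictSubcommand

parser = ArgumentParser()
parser = PredictSubcommand.add_args(parser)
args = parser.parse_args(["--test-path", "data/test.csv"])  # hypothetical flags
PredictSubcommand.func(args)
```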
chemprop-updated/chemprop/cli/train.py ADDED
@@ -0,0 +1,1343 @@
1
+ from copy import deepcopy
2
+ from io import StringIO
3
+ import json
4
+ import logging
5
+ from pathlib import Path
6
+ import sys
7
+ from tempfile import TemporaryDirectory
8
+
9
+ from configargparse import ArgumentError, ArgumentParser, Namespace
10
+ from lightning import pytorch as pl
11
+ from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
12
+ from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
13
+ from lightning.pytorch.strategies import DDPStrategy
14
+ import numpy as np
15
+ import pandas as pd
16
+ from rich.console import Console
17
+ from rich.table import Column, Table
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from chemprop.cli.common import (
22
+ add_common_args,
23
+ find_models,
24
+ process_common_args,
25
+ validate_common_args,
26
+ )
27
+ from chemprop.cli.conf import CHEMPROP_TRAIN_DIR, NOW
28
+ from chemprop.cli.utils import (
29
+ LookupAction,
30
+ Subcommand,
31
+ build_data_from_files,
32
+ get_column_names,
33
+ make_dataset,
34
+ parse_indices,
35
+ )
36
+ from chemprop.cli.utils.args import uppercase
37
+ from chemprop.data import (
38
+ MoleculeDataset,
39
+ MolGraphDataset,
40
+ MulticomponentDataset,
41
+ ReactionDatapoint,
42
+ SplitType,
43
+ build_dataloader,
44
+ make_split_indices,
45
+ split_data_by_indices,
46
+ )
47
+ from chemprop.data.datasets import _MolGraphDatasetMixin
48
+ from chemprop.models import MPNN, MulticomponentMPNN, save_model
49
+ from chemprop.nn import AggregationRegistry, LossFunctionRegistry, MetricRegistry, PredictorRegistry
50
+ from chemprop.nn.message_passing import (
51
+ AtomMessagePassing,
52
+ BondMessagePassing,
53
+ MulticomponentMessagePassing,
54
+ )
55
+ from chemprop.nn.transforms import GraphTransform, ScaleTransform, UnscaleTransform
56
+ from chemprop.nn.utils import Activation
57
+ from chemprop.utils import Factory
58
+
59
+ logger = logging.getLogger(__name__)
60
+
61
+
62
+ _CV_REMOVAL_ERROR = (
63
+ "The -k/--num-folds argument was removed in v2.1.0 - use --num-replicates instead."
64
+ )
65
+
66
+
67
+ class TrainSubcommand(Subcommand):
68
+ COMMAND = "train"
69
+ HELP = "Train a chemprop model."
70
+ parser = None
71
+
72
+ @classmethod
73
+ def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
74
+ parser = add_common_args(parser)
75
+ parser = add_train_args(parser)
76
+ cls.parser = parser
77
+ return parser
78
+
79
+ @classmethod
80
+ def func(cls, args: Namespace):
81
+ args = process_common_args(args)
82
+ validate_common_args(args)
83
+ args = process_train_args(args)
84
+ validate_train_args(args)
85
+
86
+ args.output_dir.mkdir(exist_ok=True, parents=True)
87
+ config_path = args.output_dir / "config.toml"
88
+ save_config(cls.parser, args, config_path)
89
+ main(args)
90
+
91
+
92
+ def add_train_args(parser: ArgumentParser) -> ArgumentParser:
93
+ parser.add_argument(
94
+ "--config-path",
95
+ type=Path,
96
+ is_config_file=True,
97
+ help="Path to a configuration file (command line arguments override values in the configuration file)",
98
+ )
99
+ parser.add_argument(
100
+ "-i",
101
+ "--data-path",
102
+ type=Path,
103
+ help="Path to an input CSV file containing SMILES and the associated target values",
104
+ )
105
+ parser.add_argument(
106
+ "-o",
107
+ "--output-dir",
108
+ "--save-dir",
109
+ type=Path,
110
+ help="Directory where training outputs will be saved (defaults to ``CURRENT_DIRECTORY/chemprop_training/STEM_OF_INPUT/TIME_STAMP``)",
111
+ )
112
+ parser.add_argument(
113
+ "--remove-checkpoints",
114
+ action="store_true",
115
+ help="Remove intermediate checkpoint files after training is complete.",
116
+ )
117
+
118
+ # TODO: Add in v2.1; see if we can tell lightning how often to log training loss
119
+ # parser.add_argument(
120
+ # "--log-frequency",
121
+ # type=int,
122
+ # default=10,
123
+ # help="The number of batches between each logging of the training loss.",
124
+ # )
125
+
126
+ transfer_args = parser.add_argument_group("transfer learning args")
127
+ transfer_args.add_argument(
128
+ "--checkpoint",
129
+ type=Path,
130
+ nargs="+",
131
+ help="Path to checkpoint(s) or model file(s) for loading and overwriting weights. Accepts a single pre-trained model checkpoint (.ckpt), a single model file (.pt), a directory containing such files, or a list of paths and directories. If a directory is provided, it will recursively search for and use all (.pt) files found.",
132
+ )
133
+ transfer_args.add_argument(
134
+ "--freeze-encoder",
135
+ action="store_true",
136
+ help="Freeze the message passing layer from the checkpoint model (specified by ``--checkpoint``).",
137
+ )
138
+ transfer_args.add_argument(
139
+ "--model-frzn",
140
+ help="Path to model checkpoint file to be loaded for overwriting and freezing weights. By default, all MPNN weights are frozen with this option.",
141
+ )
142
+ transfer_args.add_argument(
143
+ "--frzn-ffn-layers",
144
+ type=int,
145
+ default=0,
146
+ help="Freeze the first ``n`` layers of the FFN from the checkpoint model (specified by ``--checkpoint``). The message passing layer should also be frozen with ``--freeze-encoder``.",
147
+ )
148
+ # transfer_args.add_argument(
149
+ # "--freeze-first-only",
150
+ # action="store_true",
151
+ # help="Determines whether or not to use checkpoint_frzn for just the first encoder. Default (False) is to use the checkpoint to freeze all encoders. (only relevant for number_of_molecules > 1, where checkpoint model has number_of_molecules = 1)",
152
+ # )
153
+
154
+ # TODO: Add in v2.1
155
+ # parser.add_argument(
156
+ # "--resume-experiment",
157
+ # action="store_true",
158
+ # help="Whether to resume the experiment. Loads test results from any folds that have already been completed and skips training those folds.",
159
+ # )
160
+ # parser.add_argument(
161
+ # "--config-path",
162
+ # help="Path to a :code:`.json` file containing arguments. Any arguments present in the config file will override arguments specified via the command line or by the defaults.",
163
+ # )
164
+ parser.add_argument(
165
+ "--ensemble-size",
166
+ type=int,
167
+ default=1,
168
+ help="Number of models in the ensemble for each data split",
169
+ )
170
+
171
+ # TODO: Add in v2.2
172
+ # abt_args = parser.add_argument_group("atom/bond target args")
173
+ # abt_args.add_argument(
174
+ # "--is-atom-bond-targets",
175
+ # action="store_true",
176
+ # help="Whether this is atomic/bond properties prediction.",
177
+ # )
178
+ # abt_args.add_argument(
179
+ # "--no-adding-bond-types",
180
+ # action="store_true",
181
+ # help="Whether the bond types determined by RDKit molecules are added to the output of bond targets. This option is intended to be used with :code:`is_atom_bond_targets`.",
182
+ # )
183
+ # abt_args.add_argument(
184
+ # "--keeping-atom-map",
185
+ # action="store_true",
186
+ # help="Whether RDKit molecules keep the original atom mapping. This option is intended to be used when providing atom-mapped SMILES with the :code:`is_atom_bond_targets`.",
187
+ # )
188
+ # abt_args.add_argument(
189
+ # "--no-shared-atom-bond-ffn",
190
+ # action="store_true",
191
+ # help="Whether the FFN weights for atom and bond targets should be independent between tasks.",
192
+ # )
193
+ # abt_args.add_argument(
194
+ # "--weights-ffn-num-layers",
195
+ # type=int,
196
+ # default=2,
197
+ # help="Number of layers in FFN for determining weights used in constrained targets.",
198
+ # )
199
+
200
+ mp_args = parser.add_argument_group("message passing")
201
+ mp_args.add_argument(
202
+ "--message-hidden-dim", type=int, default=300, help="Hidden dimension of the messages"
203
+ )
204
+ mp_args.add_argument(
205
+ "--message-bias", action="store_true", help="Add bias to the message passing layers"
206
+ )
207
+ mp_args.add_argument("--depth", type=int, default=3, help="Number of message passing steps")
208
+ mp_args.add_argument(
209
+ "--undirected",
210
+ action="store_true",
211
+ help="Pass messages on undirected bonds/edges (always sum the two relevant bond vectors)",
212
+ )
213
+ mp_args.add_argument(
214
+ "--dropout",
215
+ type=float,
216
+ default=0.0,
217
+ help="Dropout probability in message passing/FFN layers",
218
+ )
219
+ mp_args.add_argument(
220
+ "--mpn-shared",
221
+ action="store_true",
222
+ help="Whether to use the same message passing neural network for all input molecules (only relevant if ``number_of_molecules`` > 1)",
223
+ )
224
+ mp_args.add_argument(
225
+ "--activation",
226
+ type=uppercase,
227
+ default="RELU",
228
+ choices=list(Activation.keys()),
229
+ help="Activation function in message passing/FFN layers",
230
+ )
231
+ mp_args.add_argument(
232
+ "--aggregation",
233
+ "--agg",
234
+ default="norm",
235
+ action=LookupAction(AggregationRegistry),
236
+ help="Aggregation mode to use in the graph predictor",
237
+ )
238
+ mp_args.add_argument(
239
+ "--aggregation-norm",
240
+ type=float,
241
+ default=100,
242
+ help="Normalization factor by which to divide summed up atomic features for ``norm`` aggregation",
243
+ )
244
+ mp_args.add_argument(
245
+ "--atom-messages", action="store_true", help="Pass messages on atoms rather than bonds."
246
+ )
247
+
248
+ # TODO: Add in v2.1
249
+ # mpsolv_args = parser.add_argument_group("message passing with solvent")
250
+ # mpsolv_args.add_argument(
251
+ # "--reaction-solvent",
252
+ # action="store_true",
253
+ # help="Whether to adjust the MPNN layer to take as input a reaction and a molecule, and to encode them with separate MPNNs.",
254
+ # )
255
+ # mpsolv_args.add_argument(
256
+ # "--bias-solvent",
257
+ # action="store_true",
258
+ # help="Whether to add bias to linear layers for solvent MPN if :code:`reaction_solvent` is True.",
259
+ # )
260
+ # mpsolv_args.add_argument(
261
+ # "--hidden-size-solvent",
262
+ # type=int,
263
+ # default=300,
264
+ # help="Dimensionality of hidden layers in solvent MPN if :code:`reaction_solvent` is True.",
265
+ # )
266
+ # mpsolv_args.add_argument(
267
+ # "--depth-solvent",
268
+ # type=int,
269
+ # default=3,
270
+ # help="Number of message passing steps for solvent if :code:`reaction_solvent` is True.",
271
+ # )
272
+
273
+ ffn_args = parser.add_argument_group("FFN args")
274
+ ffn_args.add_argument(
275
+ "--ffn-hidden-dim", type=int, default=300, help="Hidden dimension in the FFN top model"
276
+ )
277
+ ffn_args.add_argument( # TODO: the default in v1 was 2. (see weights_ffn_num_layers option) Do we really want the default to now be 1?
278
+ "--ffn-num-layers", type=int, default=1, help="Number of layers in FFN top model"
279
+ )
280
+ # TODO: Decide if we want to implement this in v2
281
+ # ffn_args.add_argument(
282
+ # "--features-only",
283
+ # action="store_true",
284
+ # help="Use only the additional features in an FFN, no graph network.",
285
+ # )
286
+
287
+ extra_mpnn_args = parser.add_argument_group("extra MPNN args")
288
+ extra_mpnn_args.add_argument(
289
+ "--batch-norm", action="store_true", help="Turn on batch normalization after aggregation"
290
+ )
291
+ extra_mpnn_args.add_argument(
292
+ "--multiclass-num-classes",
293
+ type=int,
294
+ default=3,
295
+ help="Number of classes when running multiclass classification",
296
+ )
297
+ # TODO: Add in v2.1
298
+ # extra_mpnn_args.add_argument(
299
+ # "--spectral-activation",
300
+ # default="exp",
301
+ # choices=["softplus", "exp"],
302
+ # help="Indicates which function to use in task_type spectra training to constrain outputs to be positive.",
303
+ # )
304
+
305
+ train_data_args = parser.add_argument_group("training input data args")
306
+ train_data_args.add_argument(
307
+ "-w",
308
+ "--weight-column",
309
+ help="Name of the column in the input CSV containing individual data weights",
310
+ )
311
+ train_data_args.add_argument(
312
+ "--target-columns",
313
+ nargs="+",
314
+ help="Name of the columns containing target values (by default, uses all columns except the SMILES column and the ``ignore_columns``)",
315
+ )
316
+ train_data_args.add_argument(
317
+ "--ignore-columns",
318
+ nargs="+",
319
+ help="Name of the columns to ignore when ``target_columns`` is not provided",
320
+ )
321
+ train_data_args.add_argument(
322
+ "--no-cache",
323
+ action="store_true",
324
+ help="Turn off caching the featurized ``MolGraph`` s at the beginning of training",
325
+ )
326
+ train_data_args.add_argument(
327
+ "--splits-column",
328
+ help="Name of the column in the input CSV file containing 'train', 'val', or 'test' for each row.",
329
+ )
330
+ # TODO: Add in v2.1
331
+ # train_data_args.add_argument(
332
+ # "--spectra-phase-mask-path",
333
+ # help="Path to a file containing a phase mask array, used for excluding particular regions in spectra predictions.",
334
+ # )
335
+
336
+ train_args = parser.add_argument_group("training args")
337
+ train_args.add_argument(
338
+ "-t",
339
+ "--task-type",
340
+ default="regression",
341
+ action=LookupAction(PredictorRegistry),
342
+ help="Type of dataset (determines the default loss function used during training, defaults to ``regression``)",
343
+ )
344
+ train_args.add_argument(
345
+ "-l",
346
+ "--loss-function",
347
+ action=LookupAction(LossFunctionRegistry),
348
+ help="Loss function to use during training (will use the default loss function for the given task type if not specified)",
349
+ )
350
+ train_args.add_argument(
351
+ "--v-kl",
352
+ "--evidential-regularization",
353
+ type=float,
354
+ default=0.0,
355
+ help="Specify the value used in regularization for the evidential loss function. The default value recommended by Soleimany et al. (2021) is 0.2, but the optimal value is dataset-dependent, so users are encouraged to test several values to find the best one for their model.",
356
+ )
357
+
358
+ train_args.add_argument(
359
+ "--eps", type=float, default=1e-8, help="Evidential regularization epsilon"
360
+ )
361
+ train_args.add_argument(
362
+ "--alpha", type=float, default=0.1, help="Target error bounds for quantile interval loss"
363
+ )
364
+ # TODO: Add in v2.1
365
+ # train_args.add_argument( # TODO: Is threshold the same thing as the spectra target floor? I'm not sure but combined them.
366
+ # "-T",
367
+ # "--threshold",
368
+ # "--spectra-target-floor",
369
+ # type=float,
370
+ # default=1e-8,
371
+ # help="spectral threshold limit. v1 help string: Values in targets for dataset type spectra are replaced with this value, intended to be a small positive number used to enforce positive values.",
372
+ # )
373
+ train_args.add_argument(
374
+ "--metrics",
375
+ "--metric",
376
+ nargs="+",
377
+ action=LookupAction(MetricRegistry),
378
+ help="Specify the evaluation metrics. If unspecified, chemprop will use the following metrics for given dataset types: regression -> ``rmse``, classification -> ``roc``, multiclass -> ``ce`` ('cross entropy'), spectral -> ``sid``. If multiple metrics are provided, the 0-th one will be used for early stopping and checkpointing.",
379
+ )
380
+ train_args.add_argument(
381
+ "--tracking-metric",
382
+ default="val_loss",
383
+ help="The metric to track for early stopping and checkpointing. Defaults to the criterion used during training.",
384
+ )
385
+ train_args.add_argument(
386
+ "--show-individual-scores",
387
+ action="store_true",
388
+ help="Show all scores for individual targets, not just average, at the end.",
389
+ )
390
+ train_args.add_argument(
391
+ "--task-weights",
392
+ nargs="+",
393
+ type=float,
394
+ help="Weights to apply to each task in the loss function",
395
+ )
396
+ train_args.add_argument(
397
+ "--warmup-epochs",
398
+ type=int,
399
+ default=2,
400
+ help="Number of epochs during which learning rate increases linearly from ``init_lr`` to ``max_lr`` (afterwards, learning rate decreases exponentially from ``max_lr`` to ``final_lr``)",
401
+ )
402
+
403
+ train_args.add_argument("--init-lr", type=float, default=1e-4, help="Initial learning rate.")
404
+ train_args.add_argument("--max-lr", type=float, default=1e-3, help="Maximum learning rate.")
405
+ train_args.add_argument("--final-lr", type=float, default=1e-4, help="Final learning rate.")
406
+ train_args.add_argument("--epochs", type=int, default=50, help="Number of epochs to train over")
407
+ train_args.add_argument(
408
+ "--patience",
409
+ type=int,
410
+ default=None,
411
+ help="Number of epochs to wait for improvement before early stopping",
412
+ )
413
+ train_args.add_argument(
414
+ "--grad-clip",
415
+ type=float,
416
+ help="Passed directly to the lightning trainer which controls grad clipping (see the ``Trainer()`` docstring for details)",
417
+ )
418
+ train_args.add_argument(
419
+ "--class-balance",
420
+ action="store_true",
421
+ help="Ensures each training batch contains an equal number of positive and negative samples.",
422
+ )
423
+
424
+ split_args = parser.add_argument_group("split args")
425
+ split_args.add_argument(
426
+ "--split",
427
+ "--split-type",
428
+ type=uppercase,
429
+ default="RANDOM",
430
+ choices=list(SplitType.keys()),
431
+ help="Method of splitting the data into train/val/test (case insensitive)",
432
+ )
433
+ split_args.add_argument(
434
+ "--split-sizes",
435
+ type=float,
436
+ nargs=3,
437
+ default=[0.8, 0.1, 0.1],
438
+ help="Split proportions for train/validation/test sets",
439
+ )
440
+ split_args.add_argument(
441
+ "--split-key-molecule",
442
+ type=int,
443
+ default=0,
444
+ help="Specify the index of the key molecule used for splitting when multiple molecules are present and constrained split_type is used (e.g., ``scaffold_balanced`` or ``random_with_repeated_smiles``). Note that this index begins with zero for the first molecule.",
445
+ )
446
+ split_args.add_argument("--num-replicates", type=int, default=1, help="Number of replicates.")
447
+ split_args.add_argument("-k", "--num-folds", help=_CV_REMOVAL_ERROR)
448
+ split_args.add_argument(
449
+ "--save-smiles-splits",
450
+ action="store_true",
451
+ help="Whether to store the SMILES in each train/val/test split",
452
+ )
453
+ split_args.add_argument(
454
+ "--splits-file",
455
+ type=Path,
456
+ help="Path to a JSON file containing pre-defined splits for the input data, formatted as a list of dictionaries with keys ``train``, ``val``, and ``test`` and values as lists of indices or formatted strings (e.g. [0, 1, 2, 4] or '0-2,4')",
457
+ )
458
+ split_args.add_argument(
459
+ "--data-seed",
460
+ type=int,
461
+ default=0,
462
+ help="Specify the random seed to use when splitting data into train/val/test sets. When ``--num-replicates`` > 1, the first replicate uses this seed and all subsequent replicates add 1 to the seed (also used for shuffling data in ``build_dataloader`` when ``shuffle`` is True).",
463
+ )
464
+
465
+ parser.add_argument(
466
+ "--pytorch-seed",
467
+ type=int,
468
+ default=None,
469
+ help="Seed for PyTorch randomness (e.g., random initial weights)",
470
+ )
471
+
472
+ return parser
473
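The warmup-related help strings above describe a schedule that ramps linearly from ``init_lr`` to ``max_lr`` and then decays exponentially toward ``final_lr``. A minimal per-epoch sketch of that shape, using the CLI defaults (illustrative only; the actual scheduler may step at a finer granularity):

```python
def lr_at(epoch: int, warmup: int = 2, total: int = 50,
          init: float = 1e-4, peak: float = 1e-3, final: float = 1e-4) -> float:
    if epoch < warmup:  # linear ramp: init -> peak
        return init + (peak - init) * epoch / warmup
    # exponential decay: peak -> final over the remaining epochs
    gamma = (final / peak) ** (1 / (total - warmup))
    return peak * gamma ** (epoch - warmup)

print([f"{lr_at(e):.2e}" for e in (0, 1, 2, 25, 49)])
# ['1.00e-04', '5.50e-04', '1.00e-03', '3.32e-04', '1.05e-04']
```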
+
474
+
475
+ def process_train_args(args: Namespace) -> Namespace:
476
+ if args.output_dir is None:
477
+ args.output_dir = CHEMPROP_TRAIN_DIR / args.data_path.stem / NOW
478
+
479
+ return args
480
+
481
+
482
+ def validate_train_args(args):
483
+ if args.config_path is None and args.data_path is None:
484
+ raise ArgumentError(argument=None, message="Data path must be provided for training.")
485
+
486
+ if args.num_folds is not None: # i.e. user-specified
487
+ raise ArgumentError(argument=None, message=_CV_REMOVAL_ERROR)
488
+
489
+ if args.data_path.suffix not in [".csv"]:
490
+ raise ArgumentError(
491
+ argument=None, message=f"Input data must be a CSV file. Got {args.data_path}"
492
+ )
493
+
494
+ if args.epochs != -1 and args.epochs <= args.warmup_epochs:
495
+ raise ArgumentError(
496
+ argument=None,
497
+ message=f"The number of epochs should be higher than the number of epochs during warmup. Got {args.epochs} epochs and {args.warmup_epochs} warmup epochs",
498
+ )
499
+
500
+ # TODO: model_frzn is deprecated and will be removed in v2.2
501
+ if args.checkpoint is not None and args.model_frzn is not None:
502
+ raise ArgumentError(
503
+ argument=None,
504
+ message="`--checkpoint` and `--model-frzn` cannot be used at the same time.",
505
+ )
506
+
507
+ if "--model-frzn" in sys.argv:
508
+ logger.warning(
509
+ "`--model-frzn` is deprecated and will be removed in v2.2. "
510
+ "Please use `--checkpoint` with `--freeze-encoder` instead."
511
+ )
512
+
513
+ if args.freeze_encoder and args.checkpoint is None:
514
+ raise ArgumentError(
515
+ argument=None,
516
+ message="`--freeze-encoder` can only be used when `--checkpoint` is used.",
517
+ )
518
+
519
+ if args.frzn_ffn_layers > 0:
520
+ if args.checkpoint is None and args.model_frzn is None:
521
+ raise ArgumentError(
522
+ argument=None,
523
+ message="`--frzn-ffn-layers` can only be used when `--checkpoint` or `--model-frzn` (deprecated in v2.1) is used.",
524
+ )
525
+ if args.checkpoint is not None and not args.freeze_encoder:
526
+ raise ArgumentError(
527
+ argument=None,
528
+ message="To freeze the first `n` layers of the FFN via `--frzn-ffn-layers`, the message passing layer must also be frozen with `--freeze-encoder`.",
529
+ )
530
+
531
+ if args.class_balance and args.task_type != "classification":
532
+ raise ArgumentError(
533
+ argument=None, message="Class balance is only applicable for classification tasks."
534
+ )
535
+
536
+ valid_tracking_metrics = (
537
+ args.metrics or [PredictorRegistry[args.task_type]._T_default_metric.alias]
538
+ ) + ["val_loss"]
539
+ if args.tracking_metric not in valid_tracking_metrics:
540
+ raise ArgumentError(
541
+ argument=None,
542
+ message=f"Tracking metric must be one of {','.join(valid_tracking_metrics)}. "
543
+ f"Got {args.tracking_metric}. Additional tracking metric options can be specified with "
544
+ "the `--metrics` flag.",
545
+ )
546
+
547
+ input_cols, target_cols = get_column_names(
548
+ args.data_path,
549
+ args.smiles_columns,
550
+ args.reaction_columns,
551
+ args.target_columns,
552
+ args.ignore_columns,
553
+ args.splits_column,
554
+ args.weight_column,
555
+ args.no_header_row,
556
+ )
557
+
558
+ args.input_columns = input_cols
559
+ args.target_columns = target_cols
560
+
561
+ return args
562
+
563
+
564
+ def normalize_inputs(train_dset, val_dset, args):
565
+ multicomponent = isinstance(train_dset, MulticomponentDataset)
566
+ num_components = train_dset.n_components if multicomponent else 1
567
+
568
+ X_d_transform = None
569
+ V_f_transforms = [nn.Identity()] * num_components
570
+ E_f_transforms = [nn.Identity()] * num_components
571
+ V_d_transforms = [None] * num_components
572
+ graph_transforms = []
573
+
574
+ d_xd = train_dset.d_xd
575
+ d_vf = train_dset.d_vf
576
+ d_ef = train_dset.d_ef
577
+ d_vd = train_dset.d_vd
578
+
579
+ if d_xd > 0 and not args.no_descriptor_scaling:
580
+ scaler = train_dset.normalize_inputs("X_d")
581
+ val_dset.normalize_inputs("X_d", scaler)
582
+
583
+ scaler = scaler if not isinstance(scaler, list) else scaler[0]
584
+
585
+ if scaler is not None:
586
+ logger.info(
587
+ f"Descriptors: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}"
588
+ )
589
+ X_d_transform = ScaleTransform.from_standard_scaler(scaler)
590
+
591
+ if d_vf > 0 and not args.no_atom_feature_scaling:
592
+ scaler = train_dset.normalize_inputs("V_f")
593
+ val_dset.normalize_inputs("V_f", scaler)
594
+
595
+ scalers = [scaler] if not isinstance(scaler, list) else scaler
596
+
597
+ for i, scaler in enumerate(scalers):
598
+ if scaler is None:
599
+ continue
600
+
601
+ logger.info(
602
+ f"Atom features for mol {i}: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}"
603
+ )
604
+ featurizer = (
605
+ train_dset.datasets[i].featurizer if multicomponent else train_dset.featurizer
606
+ )
607
+ V_f_transforms[i] = ScaleTransform.from_standard_scaler(
608
+ scaler, pad=featurizer.atom_fdim - featurizer.extra_atom_fdim
609
+ )
610
+
611
+ if d_ef > 0 and not args.no_bond_feature_scaling:
612
+ scaler = train_dset.normalize_inputs("E_f")
613
+ val_dset.normalize_inputs("E_f", scaler)
614
+
615
+ scalers = [scaler] if not isinstance(scaler, list) else scaler
616
+
617
+ for i, scaler in enumerate(scalers):
618
+ if scaler is None:
619
+ continue
620
+
621
+ logger.info(
622
+ f"Bond features for mol {i}: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}"
623
+ )
624
+ featurizer = (
625
+ train_dset.datasets[i].featurizer if multicomponent else train_dset.featurizer
626
+ )
627
+ E_f_transforms[i] = ScaleTransform.from_standard_scaler(
628
+ scaler, pad=featurizer.bond_fdim - featurizer.extra_bond_fdim
629
+ )
630
+
631
+ for V_f_transform, E_f_transform in zip(V_f_transforms, E_f_transforms):
632
+ graph_transforms.append(GraphTransform(V_f_transform, E_f_transform))
633
+
634
+ if d_vd > 0 and not args.no_atom_descriptor_scaling:
635
+ scaler = train_dset.normalize_inputs("V_d")
636
+ val_dset.normalize_inputs("V_d", scaler)
637
+
638
+ scalers = [scaler] if not isinstance(scaler, list) else scaler
639
+
640
+ for i, scaler in enumerate(scalers):
641
+ if scaler is None:
642
+ continue
643
+
644
+ logger.info(
645
+ f"Atom descriptors for mol {i}: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}"
646
+ )
647
+ V_d_transforms[i] = ScaleTransform.from_standard_scaler(scaler)
648
+
649
+ return X_d_transform, graph_transforms, V_d_transforms
650
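The scalers fitted in ``normalize_inputs`` behave like scikit-learn ``StandardScaler`` objects (the log lines above read their ``mean_``/``scale_`` attributes), and the wrapped transform amounts to ``(x - mean_) / scale_``. A minimal sketch, assuming scikit-learn:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

X_train = np.array([[0.0, 10.0], [2.0, 14.0], [4.0, 18.0]])
scaler = StandardScaler().fit(X_train)
print(scaler.mean_)               # [ 2. 14.]
print(scaler.scale_)              # [1.633... 3.266...]
print(scaler.transform(X_train))  # zero mean, unit variance per column
```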
+
651
+
652
+ def load_and_use_pretrained_model_scalers(model_path: Path, train_dset, val_dset) -> None:
653
+ if isinstance(train_dset, MulticomponentDataset):
654
+ _model = MulticomponentMPNN.load_from_file(model_path)
655
+ blocks = _model.message_passing.blocks
656
+ train_dsets = train_dset.datasets
657
+ val_dsets = val_dset.datasets
658
+ else:
659
+ _model = MPNN.load_from_file(model_path)
660
+ blocks = [_model.message_passing]
661
+ train_dsets = [train_dset]
662
+ val_dsets = [val_dset]
663
+
664
+ for i in range(len(blocks)):
665
+ if isinstance(_model.X_d_transform, ScaleTransform):
666
+ scaler = _model.X_d_transform.to_standard_scaler()
667
+ train_dsets[i].normalize_inputs("X_d", scaler)
668
+ val_dsets[i].normalize_inputs("X_d", scaler)
669
+
670
+ if isinstance(blocks[i].graph_transform, GraphTransform):
671
+ if isinstance(blocks[i].graph_transform.V_transform, ScaleTransform):
672
+ V_anti_pad = (
673
+ train_dsets[i].featurizer.atom_fdim - train_dsets[i].featurizer.extra_atom_fdim
674
+ )
675
+ scaler = blocks[i].graph_transform.V_transform.to_standard_scaler(
676
+ anti_pad=V_anti_pad
677
+ )
678
+ train_dsets[i].normalize_inputs("V_f", scaler)
679
+ val_dsets[i].normalize_inputs("V_f", scaler)
680
+ if isinstance(blocks[i].graph_transform.E_transform, ScaleTransform):
681
+ E_anti_pad = (
682
+ train_dsets[i].featurizer.bond_fdim - train_dsets[i].featurizer.extra_bond_fdim
683
+ )
684
+ scaler = blocks[i].graph_transform.E_transform.to_standard_scaler(
685
+ anti_pad=E_anti_pad
686
+ )
687
+ train_dsets[i].normalize_inputs("E_f", scaler)
688
+ val_dsets[i].normalize_inputs("E_f", scaler)
689
+
690
+ if isinstance(blocks[i].V_d_transform, ScaleTransform):
691
+ scaler = blocks[i].V_d_transform.to_standard_scaler()
692
+ train_dsets[i].normalize_inputs("V_d", scaler)
693
+ val_dsets[i].normalize_inputs("V_d", scaler)
694
+
695
+ if isinstance(_model.predictor.output_transform, UnscaleTransform):
696
+ scaler = _model.predictor.output_transform.to_standard_scaler()
697
+ train_dset.normalize_targets(scaler)
698
+ val_dset.normalize_targets(scaler)
699
+
700
+
701
+ def save_config(parser: ArgumentParser, args: Namespace, config_path: Path):
702
+ config_args = deepcopy(args)
703
+ for key, value in vars(config_args).items():
704
+ if isinstance(value, Path):
705
+ setattr(config_args, key, str(value))
706
+
707
+ for key in ["atom_features_path", "atom_descriptors_path", "bond_features_path"]:
708
+ if getattr(config_args, key) is not None:
709
+ for index, path in getattr(config_args, key).items():
710
+ getattr(config_args, key)[index] = str(path)
711
+
712
+ parser.write_config_file(parsed_namespace=config_args, output_file_paths=[str(config_path)])
713
+
714
+
715
+ def save_smiles_splits(args: Namespace, output_dir, train_dset, val_dset, test_dset):
716
+ match (args.smiles_columns, args.reaction_columns):
717
+ case [_, None]:
718
+ column_labels = deepcopy(args.smiles_columns)
719
+ case [None, _]:
720
+ column_labels = deepcopy(args.reaction_columns)
721
+ case _:
722
+ column_labels = deepcopy(args.smiles_columns)
723
+ column_labels.extend(args.reaction_columns)
724
+
725
+ train_smis = train_dset.names
726
+ df_train = pd.DataFrame(train_smis, columns=column_labels)
727
+ df_train.to_csv(output_dir / "train_smiles.csv", index=False)
728
+
729
+ val_smis = val_dset.names
730
+ df_val = pd.DataFrame(val_smis, columns=column_labels)
731
+ df_val.to_csv(output_dir / "val_smiles.csv", index=False)
732
+
733
+ if test_dset is not None:
734
+ test_smis = test_dset.names
735
+ df_test = pd.DataFrame(test_smis, columns=column_labels)
736
+ df_test.to_csv(output_dir / "test_smiles.csv", index=False)
737
+
738
+
739
+ def build_splits(args, format_kwargs, featurization_kwargs):
740
+ """build the train/val/test splits"""
741
+ logger.info(f"Pulling data from file: {args.data_path}")
742
+ all_data = build_data_from_files(
743
+ args.data_path,
744
+ p_descriptors=args.descriptors_path,
745
+ p_atom_feats=args.atom_features_path,
746
+ p_bond_feats=args.bond_features_path,
747
+ p_atom_descs=args.atom_descriptors_path,
748
+ **format_kwargs,
749
+ **featurization_kwargs,
750
+ )
751
+
752
+ if args.splits_column is not None:
753
+ df = pd.read_csv(
754
+ args.data_path, header=None if args.no_header_row else "infer", index_col=False
755
+ )
756
+ grouped = df.groupby(df[args.splits_column].str.lower())
757
+ train_indices = grouped.groups.get("train", pd.Index([])).tolist()
758
+ val_indices = grouped.groups.get("val", pd.Index([])).tolist()
759
+ test_indices = grouped.groups.get("test", pd.Index([])).tolist()
760
+ train_indices, val_indices, test_indices = [train_indices], [val_indices], [test_indices]
761
+
762
+ elif args.splits_file is not None:
763
+ with open(args.splits_file, "rb") as json_file:
764
+ split_idxss = json.load(json_file)
765
+ train_indices = [parse_indices(d["train"]) for d in split_idxss]
766
+ val_indices = [parse_indices(d["val"]) for d in split_idxss]
767
+ test_indices = [parse_indices(d["test"]) for d in split_idxss]
768
+ args.num_replicates = len(split_idxss)
769
+
770
+ else:
771
+ splitting_data = all_data[args.split_key_molecule]
772
+ if isinstance(splitting_data[0], ReactionDatapoint):
773
+ splitting_mols = [datapoint.rct for datapoint in splitting_data]
774
+ else:
775
+ splitting_mols = [datapoint.mol for datapoint in splitting_data]
776
+ train_indices, val_indices, test_indices = make_split_indices(
777
+ splitting_mols, args.split, args.split_sizes, args.data_seed, args.num_replicates
778
+ )
779
+
780
+ train_data, val_data, test_data = split_data_by_indices(
781
+ all_data, train_indices, val_indices, test_indices
782
+ )
783
+ for i_split in range(len(train_data)):
784
+ sizes = [len(train_data[i_split][0]), len(val_data[i_split][0]), len(test_data[i_split][0])]
785
+ logger.info(f"train/val/test split_{i_split} sizes: {sizes}")
786
+
787
+ return train_data, val_data, test_data
788
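A splits file consumed by the ``--splits-file`` branch above might look like the following (indices and file name are illustrative; per the help text, ``parse_indices`` accepts either explicit index lists or range strings):

```python
import json

splits = [
    {"train": [0, 1, 2, 4], "val": "5-7", "test": [8, 9]},
    {"train": "0-5", "val": [6, 7], "test": [8, 9]},
]
with open("splits.json", "w") as f:
    json.dump(splits, f, indent=2)
```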
+
789
+
790
+ def summarize(
791
+ target_cols: list[str], task_type: str, dataset: _MolGraphDatasetMixin
792
+ ) -> tuple[list, list]:
793
+ if task_type in [
794
+ "regression",
795
+ "regression-mve",
796
+ "regression-evidential",
797
+ "regression-quantile",
798
+ ]:
799
+ if isinstance(dataset, MulticomponentDataset):
800
+ y = dataset.datasets[0].Y
801
+ else:
802
+ y = dataset.Y
803
+ y_mean = np.nanmean(y, axis=0)
804
+ y_std = np.nanstd(y, axis=0)
805
+ y_median = np.nanmedian(y, axis=0)
806
+ mean_dev_abs = np.abs(y - y_mean)
807
+ num_targets = np.sum(~np.isnan(y), axis=0)
808
+ frac_1_sigma = np.sum((mean_dev_abs < y_std), axis=0) / num_targets
809
+ frac_2_sigma = np.sum((mean_dev_abs < 2 * y_std), axis=0) / num_targets
810
+
811
+ column_headers = ["Statistic"] + [f"Value ({target_cols[i]})" for i in range(y.shape[1])]
812
+ table_rows = [
813
+ ["Num. smiles"] + [f"{len(y)}" for i in range(y.shape[1])],
814
+ ["Num. targets"] + [f"{num_targets[i]}" for i in range(y.shape[1])],
815
+ ["Num. NaN"] + [f"{len(y) - num_targets[i]}" for i in range(y.shape[1])],
816
+ ["Mean"] + [f"{mean:0.3g}" for mean in y_mean],
817
+ ["Std. dev."] + [f"{std:0.3g}" for std in y_std],
818
+ ["Median"] + [f"{median:0.3g}" for median in y_median],
819
+ ["% within 1 s.d."] + [f"{sigma:0.0%}" for sigma in frac_1_sigma],
820
+ ["% within 2 s.d."] + [f"{sigma:0.0%}" for sigma in frac_2_sigma],
821
+ ]
822
+ return (column_headers, table_rows)
823
+ elif task_type in [
824
+ "classification",
825
+ "classification-dirichlet",
826
+ "multiclass",
827
+ "multiclass-dirichlet",
828
+ ]:
829
+ if isinstance(dataset, MulticomponentDataset):
830
+ y = dataset.datasets[0].Y
831
+ else:
832
+ y = dataset.Y
833
+
834
+ mask = np.isnan(y)
835
+ classes = np.sort(np.unique(y[~mask]))
836
+
837
+ class_counts = np.stack([(classes[:, None] == y[:, i]).sum(1) for i in range(y.shape[1])])
838
+ class_fracs = class_counts / y.shape[0]
839
+ nan_count = np.nansum(mask, axis=0)
840
+ nan_frac = nan_count / y.shape[0]
841
+
842
+ column_headers = ["Class"] + [f"Count/Percent {target_cols[i]}" for i in range(y.shape[1])]
843
+
844
+ table_rows = [
845
+ [f"{k}"] + [f"{class_counts[j, i]}/{class_fracs[j, i]:0.0%}" for j in range(y.shape[1])]
846
+ for i, k in enumerate(classes)
847
+ ]
848
+
849
+ nan_row = ["NaN"] + [f"{nan_count[i]}/{nan_frac[i]:0.0%}" for i in range(y.shape[1])]
850
+ table_rows.append(nan_row)
851
+
852
+ total_row = ["Total"] + [f"{y.shape[0]}/{100.00}%" for i in range(y.shape[1])]
853
+ table_rows.append(total_row)
854
+
855
+ return (column_headers, table_rows)
856
+ else:
857
+ raise ValueError(f"Unsupported task type: '{task_type}' was not recognized.")
858
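A quick numeric check of the sigma-coverage statistics computed in the regression branch above (data illustrative):

```python
import numpy as np

y = np.array([[0.0], [1.0], [2.0], [3.0], [100.0]])
y_mean, y_std = np.nanmean(y, axis=0), np.nanstd(y, axis=0)
mean_dev_abs = np.abs(y - y_mean)
num_targets = np.sum(~np.isnan(y), axis=0)
print(np.sum(mean_dev_abs < y_std, axis=0) / num_targets)      # [0.8] within 1 s.d.
print(np.sum(mean_dev_abs < 2 * y_std, axis=0) / num_targets)  # [1.] within 2 s.d.
```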
+
859
+
860
+ def build_table(column_headers: list[str], table_rows: list[str], title: str | None = None) -> str:
861
+ right_justified_columns = [
862
+ Column(header=column_header, justify="right") for column_header in column_headers
863
+ ]
864
+ table = Table(*right_justified_columns, title=title)
865
+ for row in table_rows:
866
+ table.add_row(*row)
867
+
868
+ console = Console(record=True, file=StringIO(), width=200)
869
+ console.print(table)
870
+ return console.export_text()
871
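Example usage of the ``rich``-based helper defined above (values hypothetical):

```python
headers = ["Statistic", "Value (logS)"]
rows = [["Mean", "-2.31"], ["Std. dev.", "1.04"]]
print(build_table(headers, rows, title="Summary of Training Data"))
```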
+
872
+
873
+ def build_datasets(args, train_data, val_data, test_data):
874
+ """build the train/val/test datasets, where :attr:`test_data` may be None"""
875
+ multicomponent = len(train_data) > 1
876
+ if multicomponent:
877
+ train_dsets = [
878
+ make_dataset(data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
879
+ for data in train_data
880
+ ]
881
+ val_dsets = [
882
+ make_dataset(data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
883
+ for data in val_data
884
+ ]
885
+ train_dset = MulticomponentDataset(train_dsets)
886
+ val_dset = MulticomponentDataset(val_dsets)
887
+ if len(test_data[0]) > 0:
888
+ test_dsets = [
889
+ make_dataset(data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
890
+ for data in test_data
891
+ ]
892
+ test_dset = MulticomponentDataset(test_dsets)
893
+ else:
894
+ test_dset = None
895
+ else:
896
+ train_data = train_data[0]
897
+ val_data = val_data[0]
898
+ test_data = test_data[0]
899
+ train_dset = make_dataset(train_data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
900
+ val_dset = make_dataset(val_data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
901
+ if len(test_data) > 0:
902
+ test_dset = make_dataset(test_data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
903
+ else:
904
+ test_dset = None
905
+ if args.task_type != "spectral":
906
+ for dataset, label in zip(
907
+ [train_dset, val_dset, test_dset], ["Training", "Validation", "Test"]
908
+ ):
909
+ column_headers, table_rows = summarize(args.target_columns, args.task_type, dataset)
910
+ output = build_table(column_headers, table_rows, f"Summary of {label} Data")
911
+ logger.info("\n" + output)
912
+
913
+ return train_dset, val_dset, test_dset
914
+
915
+
916
+ def build_model(
917
+ args,
918
+ train_dset: MolGraphDataset | MulticomponentDataset,
919
+ output_transform: UnscaleTransform,
920
+ input_transforms: tuple[ScaleTransform, list[GraphTransform], list[ScaleTransform]],
921
+ ) -> MPNN:
922
+ mp_cls = AtomMessagePassing if args.atom_messages else BondMessagePassing
923
+
924
+ X_d_transform, graph_transforms, V_d_transforms = input_transforms
925
+ if isinstance(train_dset, MulticomponentDataset):
926
+ mp_blocks = [
927
+ mp_cls(
928
+ train_dset.datasets[i].featurizer.atom_fdim,
929
+ train_dset.datasets[i].featurizer.bond_fdim,
930
+ d_h=args.message_hidden_dim,
931
+ d_vd=(
932
+ train_dset.datasets[i].d_vd
933
+ if isinstance(train_dset.datasets[i], MoleculeDataset)
934
+ else 0
935
+ ),
936
+ bias=args.message_bias,
937
+ depth=args.depth,
938
+ undirected=args.undirected,
939
+ dropout=args.dropout,
940
+ activation=args.activation,
941
+ V_d_transform=V_d_transforms[i],
942
+ graph_transform=graph_transforms[i],
943
+ )
944
+ for i in range(train_dset.n_components)
945
+ ]
946
+ if args.mpn_shared:
947
+ if args.reaction_columns is not None and args.smiles_columns is not None:
948
+ raise ArgumentError(
949
+ argument=None,
950
+ message="Cannot use shared MPNN with both molecule and reaction data.",
951
+ )
952
+
953
+ mp_block = MulticomponentMessagePassing(mp_blocks, train_dset.n_components, args.mpn_shared)
954
+ # NOTE(degraff): this if/else block should be handled by the init of MulticomponentMessagePassing
955
+ # if args.mpn_shared:
956
+ # mp_block = MulticomponentMessagePassing(mp_blocks[0], n_components, args.mpn_shared)
957
+ # else:
958
+ d_xd = train_dset.datasets[0].d_xd
959
+ n_tasks = train_dset.datasets[0].Y.shape[1]
960
+ mpnn_cls = MulticomponentMPNN
961
+ else:
962
+ mp_block = mp_cls(
963
+ train_dset.featurizer.atom_fdim,
964
+ train_dset.featurizer.bond_fdim,
965
+ d_h=args.message_hidden_dim,
966
+ d_vd=train_dset.d_vd if isinstance(train_dset, MoleculeDataset) else 0,
967
+ bias=args.message_bias,
968
+ depth=args.depth,
969
+ undirected=args.undirected,
970
+ dropout=args.dropout,
971
+ activation=args.activation,
972
+ V_d_transform=V_d_transforms[0],
973
+ graph_transform=graph_transforms[0],
974
+ )
975
+ d_xd = train_dset.d_xd
976
+ n_tasks = train_dset.Y.shape[1]
977
+ mpnn_cls = MPNN
978
+
979
+ agg = Factory.build(AggregationRegistry[args.aggregation], norm=args.aggregation_norm)
980
+ predictor_cls = PredictorRegistry[args.task_type]
981
+ if args.loss_function is not None:
982
+ task_weights = torch.ones(n_tasks) if args.task_weights is None else args.task_weights
983
+ criterion = Factory.build(
984
+ LossFunctionRegistry[args.loss_function],
985
+ task_weights=task_weights,
986
+ v_kl=args.v_kl,
987
+ # threshold=args.threshold, TODO: Add in v2.1
988
+ eps=args.eps,
989
+ alpha=args.alpha,
990
+ )
991
+ else:
992
+ criterion = None
993
+ if args.metrics is not None:
994
+ metrics = [Factory.build(MetricRegistry[metric]) for metric in args.metrics]
995
+ else:
996
+ metrics = None
997
+
998
+ predictor = Factory.build(
999
+ predictor_cls,
1000
+ input_dim=mp_block.output_dim + d_xd,
1001
+ n_tasks=n_tasks,
1002
+ hidden_dim=args.ffn_hidden_dim,
1003
+ n_layers=args.ffn_num_layers,
1004
+ dropout=args.dropout,
1005
+ activation=args.activation,
1006
+ criterion=criterion,
1007
+ task_weights=args.task_weights,
1008
+ n_classes=args.multiclass_num_classes,
1009
+ output_transform=output_transform,
1010
+ # spectral_activation=args.spectral_activation, TODO: Add in v2.1
1011
+ )
1012
+
1013
+ if args.loss_function is None:
1014
+ logger.info(
1015
+ f"No loss function was specified! Using class default: {predictor_cls._T_default_criterion}"
1016
+ )
1017
+
1018
+ return mpnn_cls(
1019
+ mp_block,
1020
+ agg,
1021
+ predictor,
1022
+ args.batch_norm,
1023
+ metrics,
1024
+ args.warmup_epochs,
1025
+ args.init_lr,
1026
+ args.max_lr,
1027
+ args.final_lr,
1028
+ X_d_transform=X_d_transform,
1029
+ )
1030
+
1031
+
1032
+ def train_model(
1033
+ args, train_loader, val_loader, test_loader, output_dir, output_transform, input_transforms
1034
+ ):
1035
+ if args.checkpoint is not None:
1036
+ model_paths = find_models(args.checkpoint)
1037
+ if args.ensemble_size != len(model_paths):
1038
+ logger.warning(
1039
+ f"The number of models in the ensemble for each data split is set to {len(model_paths)} to match the number of checkpoints provided."
1040
+ )
1041
+ args.ensemble_size = len(model_paths)
1042
+
1043
+ for model_idx in range(args.ensemble_size):
1044
+ model_output_dir = output_dir / f"model_{model_idx}"
1045
+ model_output_dir.mkdir(exist_ok=True, parents=True)
1046
+
1047
+ if args.pytorch_seed is None:
1048
+ seed = torch.seed()
1049
+ deterministic = False
1050
+ else:
1051
+ seed = args.pytorch_seed + model_idx
1052
+ deterministic = True
1053
+
1054
+ torch.manual_seed(seed)
1055
+
1056
+ if args.checkpoint or args.model_frzn is not None:
1057
+ mpnn_cls = (
1058
+ MulticomponentMPNN
1059
+ if isinstance(train_loader.dataset, MulticomponentDataset)
1060
+ else MPNN
1061
+ )
1062
+ model_path = model_paths[model_idx] if args.checkpoint else args.model_frzn
1063
+ model = mpnn_cls.load_from_file(model_path)
1064
+
1065
+ if args.checkpoint:
1066
+ model.apply(
1067
+ lambda m: setattr(m, "p", args.dropout)
1068
+ if isinstance(m, torch.nn.Dropout)
1069
+ else None
1070
+ )
1071
+
1072
+ # TODO: model_frzn is deprecated and will be removed in v2.2
1073
+ if args.model_frzn or args.freeze_encoder:
1074
+ model.message_passing.apply(lambda module: module.requires_grad_(False))
1075
+ model.message_passing.eval()
1076
+ model.bn.apply(lambda module: module.requires_grad_(False))
1077
+ model.bn.eval()
1078
+ for idx in range(args.frzn_ffn_layers):
1079
+ model.predictor.ffn[idx].requires_grad_(False)
1080
+ model.predictor.ffn[idx + 1].eval()
1081
+ else:
1082
+ model = build_model(args, train_loader.dataset, output_transform, input_transforms)
1083
+ logger.info(model)
1084
+
1085
+ try:
1086
+ trainer_logger = TensorBoardLogger(
1087
+ model_output_dir, "trainer_logs", default_hp_metric=False
1088
+ )
1089
+ except ModuleNotFoundError as e:
1090
+ logger.warning(
1091
+ f"Unable to import TensorBoardLogger, reverting to CSVLogger (original error: {e})."
1092
+ )
1093
+ trainer_logger = CSVLogger(model_output_dir, "trainer_logs")
1094
+
1095
+ if args.tracking_metric == "val_loss":
1096
+ T_tracking_metric = model.criterion.__class__
1097
+ tracking_metric = args.tracking_metric
1098
+ else:
1099
+ T_tracking_metric = MetricRegistry[args.tracking_metric]
1100
+ tracking_metric = "val/" + args.tracking_metric
1101
+
1102
+ monitor_mode = "max" if T_tracking_metric.higher_is_better else "min"
1103
+ logger.debug(f"Evaluation metric: '{T_tracking_metric.alias}', mode: '{monitor_mode}'")
1104
+
1105
+ if args.remove_checkpoints:
1106
+ temp_dir = TemporaryDirectory()
1107
+ checkpoint_dir = Path(temp_dir.name)
1108
+ else:
1109
+ checkpoint_dir = model_output_dir
1110
+
1111
+ checkpoint_filename = (
1112
+ f"best-epoch={{epoch}}-{tracking_metric.replace('/', '_')}="
1113
+ f"{{{tracking_metric}:.2f}}"
1114
+ )
1115
+ checkpointing = ModelCheckpoint(
1116
+ checkpoint_dir / "checkpoints",
1117
+ checkpoint_filename,
1118
+ tracking_metric,
1119
+ mode=monitor_mode,
1120
+ save_last=True,
1121
+ auto_insert_metric_name=False,
1122
+ )
1123
+
1124
+ if args.epochs != -1:
1125
+ patience = args.patience if args.patience is not None else args.epochs
1126
+ early_stopping = EarlyStopping(tracking_metric, patience=patience, mode=monitor_mode)
1127
+ callbacks = [checkpointing, early_stopping]
1128
+ else:
1129
+ callbacks = [checkpointing]
1130
+
1131
+ trainer = pl.Trainer(
1132
+ logger=trainer_logger,
1133
+ enable_progress_bar=True,
1134
+ accelerator=args.accelerator,
1135
+ devices=args.devices,
1136
+ max_epochs=args.epochs,
1137
+ callbacks=callbacks,
1138
+ gradient_clip_val=args.grad_clip,
1139
+ deterministic=deterministic,
1140
+ )
1141
+ trainer.fit(model, train_loader, val_loader)
1142
+
1143
+ if test_loader is not None:
1144
+ if isinstance(trainer.strategy, DDPStrategy):
1145
+ torch.distributed.destroy_process_group()
1146
+
1147
+ best_ckpt_path = trainer.checkpoint_callback.best_model_path
1148
+ trainer = pl.Trainer(
1149
+ logger=trainer_logger,
1150
+ enable_progress_bar=True,
1151
+ accelerator=args.accelerator,
1152
+ devices=1,
1153
+ )
1154
+ model = model.load_from_checkpoint(best_ckpt_path)
1155
+ predss = trainer.predict(model, dataloaders=test_loader)
1156
+ else:
1157
+ predss = trainer.predict(dataloaders=test_loader)
1158
+
1159
+ preds = torch.concat(predss, 0)
1160
+ if model.predictor.n_targets > 1:
1161
+ preds = preds[..., 0]
1162
+ preds = preds.numpy()
1163
+
1164
+ evaluate_and_save_predictions(
1165
+ preds, test_loader, model.metrics[:-1], model_output_dir, args
1166
+ )
1167
+
1168
+ best_model_path = checkpointing.best_model_path
1169
+ model = model.__class__.load_from_checkpoint(best_model_path)
1170
+ p_model = model_output_dir / "best.pt"
1171
+ save_model(p_model, model, args.target_columns)
1172
+ logger.info(f"Best model saved to '{p_model}'")
1173
+
1174
+ if args.remove_checkpoints:
1175
+ temp_dir.cleanup()
1176
+
1177
+
1178
+ def evaluate_and_save_predictions(preds, test_loader, metrics, model_output_dir, args):
1179
+ if isinstance(test_loader.dataset, MulticomponentDataset):
1180
+ test_dset = test_loader.dataset.datasets[0]
1181
+ else:
1182
+ test_dset = test_loader.dataset
1183
+ targets = test_dset.Y
1184
+ mask = torch.from_numpy(np.isfinite(targets))
1185
+ targets = np.nan_to_num(targets, nan=0.0)
1186
+ weights = torch.ones(len(test_dset))
1187
+ lt_mask = torch.from_numpy(test_dset.lt_mask) if test_dset.lt_mask[0] is not None else None
1188
+ gt_mask = torch.from_numpy(test_dset.gt_mask) if test_dset.gt_mask[0] is not None else None
1189
+
1190
+ individual_scores = dict()
1191
+ for metric in metrics:
1192
+ individual_scores[metric.alias] = []
1193
+ for i, col in enumerate(args.target_columns):
1194
+ if "multiclass" in args.task_type:
1195
+ preds_slice = torch.from_numpy(preds[:, i : i + 1, :])
1196
+ targets_slice = torch.from_numpy(targets[:, i : i + 1])
1197
+ else:
1198
+ preds_slice = torch.from_numpy(preds[:, i : i + 1])
1199
+ targets_slice = torch.from_numpy(targets[:, i : i + 1])
1200
+ preds_loss = metric(
1201
+ preds_slice,
1202
+ targets_slice,
1203
+ mask[:, i : i + 1],
1204
+ weights,
1205
+ lt_mask[:, i] if lt_mask is not None else None,
1206
+ gt_mask[:, i] if gt_mask is not None else None,
1207
+ )
1208
+ individual_scores[metric.alias].append(preds_loss)
1209
+
1210
+ logger.info("Test Set results:")
1211
+ for metric in metrics:
1212
+ avg_loss = sum(individual_scores[metric.alias]) / len(individual_scores[metric.alias])
1213
+ logger.info(f"test/{metric.alias}: {avg_loss}")
1214
+
1215
+ if args.show_individual_scores:
1216
+ logger.info("Entire Test Set individual results:")
1217
+ for metric in metrics:
1218
+ for i, col in enumerate(args.target_columns):
1219
+ logger.info(f"test/{col}/{metric.alias}: {individual_scores[metric.alias][i]}")
1220
+
1221
+ names = test_loader.dataset.names
1222
+ if isinstance(test_loader.dataset, MulticomponentDataset):
1223
+ namess = list(zip(*names))
1224
+ else:
1225
+ namess = [names]
1226
+
1227
+ columns = args.input_columns + args.target_columns
1228
+ if "multiclass" in args.task_type:
1229
+ columns = columns + [f"{col}_prob" for col in args.target_columns]
1230
+ formatted_probability_strings = np.apply_along_axis(
1231
+ lambda x: ",".join(map(str, x)), 2, preds
1232
+ )
1233
+ predicted_class_labels = preds.argmax(axis=-1)
1234
+ df_preds = pd.DataFrame(
1235
+ list(zip(*namess, *predicted_class_labels.T, *formatted_probability_strings.T)),
1236
+ columns=columns,
1237
+ )
1238
+ else:
1239
+ df_preds = pd.DataFrame(list(zip(*namess, *preds.T)), columns=columns)
1240
+ df_preds.to_csv(model_output_dir / "test_predictions.csv", index=False)
1241
+
1242
+
1243
+ def main(args):
1244
+ format_kwargs = dict(
1245
+ no_header_row=args.no_header_row,
1246
+ smiles_cols=args.smiles_columns,
1247
+ rxn_cols=args.reaction_columns,
1248
+ target_cols=args.target_columns,
1249
+ ignore_cols=args.ignore_columns,
1250
+ splits_col=args.splits_column,
1251
+ weight_col=args.weight_column,
1252
+ bounded=args.loss_function is not None and "bounded" in args.loss_function,
1253
+ )
1254
+
1255
+ featurization_kwargs = dict(
1256
+ molecule_featurizers=args.molecule_featurizers,
1257
+ keep_h=args.keep_h,
1258
+ add_h=args.add_h,
1259
+ ignore_chirality=args.ignore_chirality,
1260
+ )
1261
+
1262
+ splits = build_splits(args, format_kwargs, featurization_kwargs)
1263
+
1264
+ for replicate_idx, (train_data, val_data, test_data) in enumerate(zip(*splits)):
1265
+ if args.num_replicates == 1:
1266
+ output_dir = args.output_dir
1267
+ else:
1268
+ output_dir = args.output_dir / f"replicate_{replicate_idx}"
1269
+
1270
+ output_dir.mkdir(exist_ok=True, parents=True)
1271
+
1272
+ train_dset, val_dset, test_dset = build_datasets(args, train_data, val_data, test_data)
1273
+
1274
+ if args.save_smiles_splits:
1275
+ save_smiles_splits(args, output_dir, train_dset, val_dset, test_dset)
1276
+
1277
+ if args.checkpoint or args.model_frzn is not None:
1278
+ model_paths = find_models(args.checkpoint)
1279
+ if len(model_paths) > 1:
1280
+ logger.warning(
1281
+ "Multiple checkpoint files were loaded, but only the scalers from "
1282
+ f"{model_paths[0]} are used. It is assumed that all models provided have the "
1283
+ "same data scalings, meaning they were trained on the same data."
1284
+ )
1285
+ model_path = model_paths[0] if args.checkpoint else args.model_frzn
1286
+ load_and_use_pretrained_model_scalers(model_path, train_dset, val_dset)
1287
+ input_transforms = (None, None, None)
1288
+ output_transform = None
1289
+ else:
1290
+ input_transforms = normalize_inputs(train_dset, val_dset, args)
1291
+
1292
+ if "regression" in args.task_type:
1293
+ output_scaler = train_dset.normalize_targets()
1294
+ val_dset.normalize_targets(output_scaler)
1295
+ logger.info(
1296
+ f"Train data: mean = {output_scaler.mean_} | std = {output_scaler.scale_}"
1297
+ )
1298
+ output_transform = UnscaleTransform.from_standard_scaler(output_scaler)
1299
+ else:
1300
+ output_transform = None
1301
+
1302
+ if not args.no_cache:
1303
+ train_dset.cache = True
1304
+ val_dset.cache = True
1305
+
1306
+ train_loader = build_dataloader(
1307
+ train_dset,
1308
+ args.batch_size,
1309
+ args.num_workers,
1310
+ class_balance=args.class_balance,
1311
+ seed=args.data_seed,
1312
+ )
1313
+ if args.class_balance:
1314
+ logger.debug(
1315
+ f"With `--class-balance`, effective train size = {len(train_loader.sampler)}"
1316
+ )
1317
+ val_loader = build_dataloader(val_dset, args.batch_size, args.num_workers, shuffle=False)
1318
+ if test_dset is not None:
1319
+ test_loader = build_dataloader(
1320
+ test_dset, args.batch_size, args.num_workers, shuffle=False
1321
+ )
1322
+ else:
1323
+ test_loader = None
1324
+
1325
+ train_model(
1326
+ args,
1327
+ train_loader,
1328
+ val_loader,
1329
+ test_loader,
1330
+ output_dir,
1331
+ output_transform,
1332
+ input_transforms,
1333
+ )
1334
+
1335
+
1336
+ if __name__ == "__main__":
1337
+ # TODO: update this old code or remove it.
1338
+ parser = ArgumentParser()
1339
+ parser = TrainSubcommand.add_args(parser)
1340
+
1341
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
1342
+ args = parser.parse_args()
1343
+ TrainSubcommand.func(args)
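
Note: the `checkpoint_filename` template built in `train_model` above relies on Lightning's placeholder substitution (`auto_insert_metric_name=False` keeps the literal `{metric}` fields in the name). A minimal sketch with a hypothetical metric value, using `str.format_map` to mimic how the placeholders resolve:

    tracking_metric = "val/rmse"  # hypothetical example value, not read from args
    template = (
        f"best-epoch={{epoch}}-{tracking_metric.replace('/', '_')}="
        f"{{{tracking_metric}:.2f}}"
    )
    # the saver fills the `{epoch}` and `{val/rmse:.2f}` placeholders:
    print(template.format_map({"epoch": 7, "val/rmse": 0.4321}))
    # -> best-epoch=7-val_rmse=0.43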
chemprop-updated/chemprop/cli/utils/__init__.py ADDED
@@ -0,0 +1,30 @@
+ from .actions import LookupAction
+ from .args import bounded
+ from .command import Subcommand
+ from .parsing import (
+     build_data_from_files,
+     get_column_names,
+     make_datapoints,
+     make_dataset,
+     parse_indices,
+ )
+ from .utils import _pop_attr, _pop_attr_d, pop_attr
+
+ __all__ = [
+     "bounded",
+     "LookupAction",
+     "Subcommand",
+     "build_data_from_files",
+     "make_datapoints",
+     "make_dataset",
+     "get_column_names",
+     "parse_indices",
+     "actions",
+     "args",
+     "command",
+     "parsing",
+     "utils",
+     "pop_attr",
+     "_pop_attr",
+     "_pop_attr_d",
+ ]
chemprop-updated/chemprop/cli/utils/actions.py ADDED
@@ -0,0 +1,19 @@
+ from argparse import _StoreAction
+ from typing import Any, Mapping
+
+
+ def LookupAction(obj: Mapping[str, Any]):
+     class LookupAction_(_StoreAction):
+         def __init__(self, option_strings, dest, default=None, choices=None, **kwargs):
+             if default not in obj.keys() and default is not None:
+                 raise ValueError(
+                     f"Invalid value for arg 'default': '{default}'. "
+                     f"Expected one of {tuple(obj.keys())}"
+                 )
+
+             kwargs["choices"] = choices if choices is not None else obj.keys()
+             kwargs["default"] = default
+
+             super().__init__(option_strings, dest, **kwargs)
+
+     return LookupAction_
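
A quick sketch of how `LookupAction` is meant to be wired into `argparse`; the registry below is hypothetical, and unknown names are rejected via the generated `choices`:

    from argparse import ArgumentParser

    FEATURIZERS = {"morgan": object, "rdkit_2d": object}  # hypothetical registry

    parser = ArgumentParser()
    parser.add_argument("--featurizer", action=LookupAction(FEATURIZERS), default="morgan")
    args = parser.parse_args(["--featurizer", "rdkit_2d"])
    print(args.featurizer)  # 'rdkit_2d'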
chemprop-updated/chemprop/cli/utils/args.py ADDED
@@ -0,0 +1,34 @@
+ import functools
+
+ __all__ = ["bounded"]
+
+
+ def bounded(lo: float | None = None, hi: float | None = None):
+     if lo is None and hi is None:
+         raise ValueError("No bounds provided!")
+
+     def decorator(f):
+         @functools.wraps(f)
+         def wrapper(*args, **kwargs):
+             x = f(*args, **kwargs)
+
+             if (lo is not None and hi is not None) and not lo <= x <= hi:
+                 raise ValueError(f"Parsed value outside of range [{lo}, {hi}]! got: {x}")
+             if hi is not None and x > hi:
+                 raise ValueError(f"Parsed value above {hi}! got: {x}")
+             if lo is not None and x < lo:
+                 raise ValueError(f"Parsed value below {lo}! got: {x}")
+
+             return x
+
+         return wrapper
+
+     return decorator
+
+
+ def uppercase(x: str):
+     return x.upper()
+
+
+ def lowercase(x: str):
+     return x.lower()
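
A sketch of `bounded` used as an `argparse`-style type factory (the function name is illustrative):

    @bounded(lo=0.0, hi=1.0)
    def unit_float(arg: str) -> float:
        return float(arg)

    unit_float("0.25")  # -> 0.25
    unit_float("1.50")  # raises ValueError: value outside [0.0, 1.0]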
chemprop-updated/chemprop/cli/utils/command.py ADDED
@@ -0,0 +1,24 @@
+ from abc import ABC, abstractmethod
+ from argparse import ArgumentParser, Namespace, _SubParsersAction
+
+
+ class Subcommand(ABC):
+     COMMAND: str
+     HELP: str | None = None
+
+     @classmethod
+     def add(cls, subparsers: _SubParsersAction, parents) -> ArgumentParser:
+         parser = subparsers.add_parser(cls.COMMAND, help=cls.HELP, parents=parents)
+         cls.add_args(parser).set_defaults(func=cls.func)
+
+         return parser
+
+     @classmethod
+     @abstractmethod
+     def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
+         pass
+
+     @classmethod
+     @abstractmethod
+     def func(cls, args: Namespace):
+         pass
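
For illustration, a minimal concrete `Subcommand`; the `echo` command is hypothetical, and `ArgumentParser`/`Namespace` come from the module above:

    class EchoSubcommand(Subcommand):
        COMMAND = "echo"
        HELP = "print a message and exit"

        @classmethod
        def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
            parser.add_argument("message")
            return parser

        @classmethod
        def func(cls, args: Namespace):
            print(args.message)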
chemprop-updated/chemprop/cli/utils/parsing.py ADDED
@@ -0,0 +1,457 @@
+ import logging
+ from os import PathLike
+ from typing import Literal, Mapping, Sequence
+
+ import numpy as np
+ import pandas as pd
+
+ from chemprop.data.datapoints import MoleculeDatapoint, ReactionDatapoint
+ from chemprop.data.datasets import MoleculeDataset, ReactionDataset
+ from chemprop.featurizers.atom import get_multi_hot_atom_featurizer
+ from chemprop.featurizers.bond import MultiHotBondFeaturizer, RIGRBondFeaturizer
+ from chemprop.featurizers.molecule import MoleculeFeaturizerRegistry
+ from chemprop.featurizers.molgraph import (
+     CondensedGraphOfReactionFeaturizer,
+     SimpleMoleculeMolGraphFeaturizer,
+ )
+ from chemprop.utils import make_mol
+
+ logger = logging.getLogger(__name__)
+
+
+ def parse_csv(
+     path: PathLike,
+     smiles_cols: Sequence[str] | None,
+     rxn_cols: Sequence[str] | None,
+     target_cols: Sequence[str] | None,
+     ignore_cols: Sequence[str] | None,
+     splits_col: str | None,
+     weight_col: str | None,
+     bounded: bool = False,
+     no_header_row: bool = False,
+ ):
+     df = pd.read_csv(path, header=None if no_header_row else "infer", index_col=False)
+
+     if smiles_cols is not None and rxn_cols is not None:
+         smiss = df[smiles_cols].T.values.tolist()
+         rxnss = df[rxn_cols].T.values.tolist()
+         input_cols = [*smiles_cols, *rxn_cols]
+     elif smiles_cols is not None and rxn_cols is None:
+         smiss = df[smiles_cols].T.values.tolist()
+         rxnss = None
+         input_cols = smiles_cols
+     elif smiles_cols is None and rxn_cols is not None:
+         smiss = None
+         rxnss = df[rxn_cols].T.values.tolist()
+         input_cols = rxn_cols
+     else:
+         smiss = df.iloc[:, [0]].T.values.tolist()
+         rxnss = None
+         input_cols = [df.columns[0]]
+
+     if target_cols is None:
+         target_cols = list(
+             column
+             for column in df.columns
+             if column
+             not in set(  # if splits or weight is None, df.columns will never have None
+                 input_cols + (ignore_cols or []) + [splits_col] + [weight_col]
+             )
+         )
+
+     Y = df[target_cols]
+     weights = None if weight_col is None else df[weight_col].to_numpy(np.single)
+
+     if bounded:
+         Y = Y.astype(str)
+         lt_mask = Y.applymap(lambda x: "<" in x).to_numpy()
+         gt_mask = Y.applymap(lambda x: ">" in x).to_numpy()
+         Y = Y.applymap(lambda x: x.strip("<").strip(">")).to_numpy(np.single)
+     else:
+         Y = Y.to_numpy(np.single)
+         lt_mask = None
+         gt_mask = None
+
+     return smiss, rxnss, Y, weights, lt_mask, gt_mask
+
+
+ def get_column_names(
+     path: PathLike,
+     smiles_cols: Sequence[str] | None,
+     rxn_cols: Sequence[str] | None,
+     target_cols: Sequence[str] | None,
+     ignore_cols: Sequence[str] | None,
+     splits_col: str | None,
+     weight_col: str | None,
+     no_header_row: bool = False,
+ ) -> tuple[list[str], list[str]]:
+     df_cols = pd.read_csv(path, index_col=False, nrows=0).columns.tolist()
+
+     if no_header_row:
+         return ["SMILES"], ["pred_" + str(i) for i in range((len(df_cols) - 1))]
+
+     input_cols = (smiles_cols or []) + (rxn_cols or [])
+
+     if len(input_cols) == 0:
+         input_cols = [df_cols[0]]
+
+     if target_cols is None:
+         target_cols = list(
+             column
+             for column in df_cols
+             if column
+             not in set(
+                 input_cols + (ignore_cols or []) + ([splits_col] or []) + ([weight_col] or [])
+             )
+         )
+
+     return input_cols, target_cols
+
+
+ def make_datapoints(
+     smiss: list[list[str]] | None,
+     rxnss: list[list[str]] | None,
+     Y: np.ndarray,
+     weights: np.ndarray | None,
+     lt_mask: np.ndarray | None,
+     gt_mask: np.ndarray | None,
+     X_d: np.ndarray | None,
+     V_fss: list[list[np.ndarray] | list[None]] | None,
+     E_fss: list[list[np.ndarray] | list[None]] | None,
+     V_dss: list[list[np.ndarray] | list[None]] | None,
+     molecule_featurizers: list[str] | None,
+     keep_h: bool,
+     add_h: bool,
+     ignore_chirality: bool,
+ ) -> tuple[list[list[MoleculeDatapoint]], list[list[ReactionDatapoint]]]:
+     """Make the :class:`MoleculeDatapoint`\s and :class:`ReactionDatapoint`\s for a given
+     dataset.
+
+     Parameters
+     ----------
+     smiss : list[list[str]] | None
+         a list of ``j`` lists of ``n`` SMILES strings, where ``j`` is the number of molecules per
+         datapoint and ``n`` is the number of datapoints. If ``None``, the corresponding list of
+         :class:`MoleculeDatapoint`\s will be empty.
+     rxnss : list[list[str]] | None
+         a list of ``k`` lists of ``n`` reaction SMILES strings, where ``k`` is the number of
+         reactions per datapoint. If ``None``, the corresponding list of
+         :class:`ReactionDatapoint`\s will be empty.
+     Y : np.ndarray
+         the target values of shape ``n x m``, where ``m`` is the number of targets
+     weights : np.ndarray | None
+         the weights of the datapoints to use in the loss function of shape ``n x m``. If ``None``,
+         the weights all default to 1.
+     lt_mask : np.ndarray | None
+         a boolean mask of shape ``n x m`` indicating whether the targets are less-than inequality
+         targets. If ``None``, ``lt_mask`` for all datapoints will be ``None``.
+     gt_mask : np.ndarray | None
+         a boolean mask of shape ``n x m`` indicating whether the targets are greater-than
+         inequality targets. If ``None``, ``gt_mask`` for all datapoints will be ``None``.
+     X_d : np.ndarray | None
+         the extra descriptors of shape ``n x p``, where ``p`` is the number of extra descriptors.
+         If ``None``, ``x_d`` for all datapoints will be ``None``.
+     V_fss : list[list[np.ndarray] | list[None]] | None
+         a list of ``j`` lists of ``n`` np.ndarrays each of shape ``v_jn x q_j``, where ``v_jn`` is
+         the number of atoms in the j-th molecule of the n-th datapoint and ``q_j`` is the number of
+         extra atom features used for the j-th molecules. Any of the ``j`` lists can be a list of
+         None values if the corresponding component does not use extra atom features. If ``None``,
+         ``V_f`` for all datapoints will be ``None``.
+     E_fss : list[list[np.ndarray] | list[None]] | None
+         a list of ``j`` lists of ``n`` np.ndarrays each of shape ``e_jn x r_j``, where ``e_jn`` is
+         the number of bonds in the j-th molecule of the n-th datapoint and ``r_j`` is the number of
+         extra bond features used for the j-th molecules. Any of the ``j`` lists can be a list of
+         None values if the corresponding component does not use extra bond features. If ``None``,
+         ``E_f`` for all datapoints will be ``None``.
+     V_dss : list[list[np.ndarray] | list[None]] | None
+         a list of ``j`` lists of ``n`` np.ndarrays each of shape ``v_jn x s_j``, where ``s_j`` is
+         the number of extra atom descriptors used for the j-th molecules. Any of the ``j`` lists
+         can be a list of None values if the corresponding component does not use extra atom
+         descriptors. If ``None``, ``V_d`` for all datapoints will be ``None``.
+     molecule_featurizers : list[str] | None
+         a list of molecule featurizer names to generate additional molecule features to use as
+         extra descriptors. If there are multiple molecules per datapoint, the featurizers will be
+         applied to each molecule and concatenated. Note that a :code:`ReactionDatapoint` has two
+         RDKit :class:`~rdkit.Chem.Mol` objects, reactant(s) and product(s). Each
+         ``molecule_featurizer`` will be applied to both of these objects.
+     keep_h : bool
+         whether to keep hydrogen atoms
+     add_h : bool
+         whether to add hydrogen atoms
+     ignore_chirality : bool
+         whether to ignore chirality information
+
+     Returns
+     -------
+     list[list[MoleculeDatapoint]]
+         a list of ``j`` lists of ``n`` :class:`MoleculeDatapoint`\s
+     list[list[ReactionDatapoint]]
+         a list of ``k`` lists of ``n`` :class:`ReactionDatapoint`\s
+
+     .. note::
+         either ``j`` or ``k`` may be 0, in which case the corresponding list will be empty.
+
+     Raises
+     ------
+     ValueError
+         if both ``smiss`` and ``rxnss`` are ``None``.
+         if ``smiss`` and ``rxnss`` are both given and have different lengths.
+     """
+     if smiss is None and rxnss is None:
+         raise ValueError("args 'smiss' and 'rxnss' were both `None`!")
+     elif rxnss is None:
+         N = len(smiss[0])
+         rxnss = []
+     elif smiss is None:
+         N = len(rxnss[0])
+         smiss = []
+     elif len(smiss[0]) != len(rxnss[0]):
+         raise ValueError(
+             f"args 'smiss' and 'rxnss' must have same length! got {len(smiss[0])} and {len(rxnss[0])}"
+         )
+     else:
+         N = len(smiss[0])
+
+     if len(smiss) > 0:
+         molss = [[make_mol(smi, keep_h, add_h, ignore_chirality) for smi in smis] for smis in smiss]
+     if len(rxnss) > 0:
+         rctss = [
+             [
+                 make_mol(
+                     f"{rct_smi}.{agt_smi}" if agt_smi else rct_smi, keep_h, add_h, ignore_chirality
+                 )
+                 for rct_smi, agt_smi, _ in (rxn.split(">") for rxn in rxns)
+             ]
+             for rxns in rxnss
+         ]
+         pdtss = [
+             [
+                 make_mol(pdt_smi, keep_h, add_h, ignore_chirality)
+                 for _, _, pdt_smi in (rxn.split(">") for rxn in rxns)
+             ]
+             for rxns in rxnss
+         ]
+
+     weights = np.ones(N, dtype=np.single) if weights is None else weights
+     gt_mask = [None] * N if gt_mask is None else gt_mask
+     lt_mask = [None] * N if lt_mask is None else lt_mask
+
+     n_mols = len(smiss) if smiss else 0
+     V_fss = [[None] * N] * n_mols if V_fss is None else V_fss
+     E_fss = [[None] * N] * n_mols if E_fss is None else E_fss
+     V_dss = [[None] * N] * n_mols if V_dss is None else V_dss
+
+     if X_d is None and molecule_featurizers is None:
+         X_d = [None] * N
+     elif molecule_featurizers is None:
+         pass
+     else:
+         molecule_featurizers = [MoleculeFeaturizerRegistry[mf]() for mf in molecule_featurizers]
+
+         if len(smiss) > 0:
+             mol_descriptors = np.hstack(
+                 [
+                     np.vstack([np.hstack([mf(mol) for mf in molecule_featurizers]) for mol in mols])
+                     for mols in molss
+                 ]
+             )
+             if X_d is None:
+                 X_d = mol_descriptors
+             else:
+                 X_d = np.hstack([X_d, mol_descriptors])
+
+         if len(rxnss) > 0:
+             rct_pdt_descriptors = np.hstack(
+                 [
+                     np.vstack(
+                         [
+                             np.hstack(
+                                 [mf(mol) for mf in molecule_featurizers for mol in (rct, pdt)]
+                             )
+                             for rct, pdt in zip(rcts, pdts)
+                         ]
+                     )
+                     for rcts, pdts in zip(rctss, pdtss)
+                 ]
+             )
+             if X_d is None:
+                 X_d = rct_pdt_descriptors
+             else:
+                 X_d = np.hstack([X_d, rct_pdt_descriptors])
+
+     mol_data = [
+         [
+             MoleculeDatapoint(
+                 mol=molss[mol_idx][i],
+                 name=smis[i],
+                 y=Y[i],
+                 weight=weights[i],
+                 gt_mask=gt_mask[i],
+                 lt_mask=lt_mask[i],
+                 x_d=X_d[i],
+                 x_phase=None,
+                 V_f=V_fss[mol_idx][i],
+                 E_f=E_fss[mol_idx][i],
+                 V_d=V_dss[mol_idx][i],
+             )
+             for i in range(N)
+         ]
+         for mol_idx, smis in enumerate(smiss)
+     ]
+     rxn_data = [
+         [
+             ReactionDatapoint(
+                 rct=rctss[rxn_idx][i],
+                 pdt=pdtss[rxn_idx][i],
+                 name=rxns[i],
+                 y=Y[i],
+                 weight=weights[i],
+                 gt_mask=gt_mask[i],
+                 lt_mask=lt_mask[i],
+                 x_d=X_d[i],
+                 x_phase=None,
+             )
+             for i in range(N)
+         ]
+         for rxn_idx, rxns in enumerate(rxnss)
+     ]
+
+     return mol_data, rxn_data
+
+
+ def build_data_from_files(
+     p_data: PathLike,
+     no_header_row: bool,
+     smiles_cols: Sequence[str] | None,
+     rxn_cols: Sequence[str] | None,
+     target_cols: Sequence[str] | None,
+     ignore_cols: Sequence[str] | None,
+     splits_col: str | None,
+     weight_col: str | None,
+     bounded: bool,
+     p_descriptors: PathLike,
+     p_atom_feats: dict[int, PathLike],
+     p_bond_feats: dict[int, PathLike],
+     p_atom_descs: dict[int, PathLike],
+     **featurization_kwargs: Mapping,
+ ) -> list[list[MoleculeDatapoint] | list[ReactionDatapoint]]:
+     smiss, rxnss, Y, weights, lt_mask, gt_mask = parse_csv(
+         p_data,
+         smiles_cols,
+         rxn_cols,
+         target_cols,
+         ignore_cols,
+         splits_col,
+         weight_col,
+         bounded,
+         no_header_row,
+     )
+     n_molecules = len(smiss) if smiss is not None else 0
+     n_datapoints = len(Y)
+
+     X_ds = load_input_feats_and_descs(p_descriptors, None, None, feat_desc="X_d")
+     V_fss = load_input_feats_and_descs(p_atom_feats, n_molecules, n_datapoints, feat_desc="V_f")
+     E_fss = load_input_feats_and_descs(p_bond_feats, n_molecules, n_datapoints, feat_desc="E_f")
+     V_dss = load_input_feats_and_descs(p_atom_descs, n_molecules, n_datapoints, feat_desc="V_d")
+
+     mol_data, rxn_data = make_datapoints(
+         smiss,
+         rxnss,
+         Y,
+         weights,
+         lt_mask,
+         gt_mask,
+         X_ds,
+         V_fss,
+         E_fss,
+         V_dss,
+         **featurization_kwargs,
+     )
+
+     return mol_data + rxn_data
+
+
+ def load_input_feats_and_descs(
+     paths: dict[int, PathLike] | PathLike,
+     n_molecules: int | None,
+     n_datapoints: int | None,
+     feat_desc: str,
+ ):
+     if paths is None:
+         return None
+
+     match feat_desc:
+         case "X_d":
+             path = paths
+             loaded_feature = np.load(path)
+             features = loaded_feature["arr_0"]
+
+         case _:
+             for index in paths:
+                 if index >= n_molecules:
+                     raise ValueError(
+                         f"For {n_molecules} molecules, atom/bond features/descriptors can only be "
+                         f"specified for indices 0-{n_molecules - 1}! Got index {index}."
+                     )
+
+             features = []
+             for idx in range(n_molecules):
+                 path = paths.get(idx, None)
+
+                 if path is not None:
+                     loaded_feature = np.load(path)
+                     loaded_feature = [
+                         loaded_feature[f"arr_{i}"] for i in range(len(loaded_feature))
+                     ]
+                 else:
+                     loaded_feature = [None] * n_datapoints
+
+                 features.append(loaded_feature)
+     return features
+
+
+ def make_dataset(
+     data: Sequence[MoleculeDatapoint] | Sequence[ReactionDatapoint],
+     reaction_mode: str,
+     multi_hot_atom_featurizer_mode: Literal["V1", "V2", "ORGANIC", "RIGR"] = "V2",
+ ) -> MoleculeDataset | ReactionDataset:
+     atom_featurizer = get_multi_hot_atom_featurizer(multi_hot_atom_featurizer_mode)
+     match multi_hot_atom_featurizer_mode:
+         case "RIGR":
+             bond_featurizer = RIGRBondFeaturizer()
+         case "V1" | "V2" | "ORGANIC":
+             bond_featurizer = MultiHotBondFeaturizer()
+         case _:
+             raise TypeError(
+                 f"Unsupported atom featurizer mode '{multi_hot_atom_featurizer_mode=}'!"
+             )
+
+     if isinstance(data[0], MoleculeDatapoint):
+         extra_atom_fdim = data[0].V_f.shape[1] if data[0].V_f is not None else 0
+         extra_bond_fdim = data[0].E_f.shape[1] if data[0].E_f is not None else 0
+         featurizer = SimpleMoleculeMolGraphFeaturizer(
+             atom_featurizer=atom_featurizer,
+             bond_featurizer=bond_featurizer,
+             extra_atom_fdim=extra_atom_fdim,
+             extra_bond_fdim=extra_bond_fdim,
+         )
+         return MoleculeDataset(data, featurizer)
+
+     featurizer = CondensedGraphOfReactionFeaturizer(
+         mode_=reaction_mode, atom_featurizer=atom_featurizer
+     )
+
+     return ReactionDataset(data, featurizer)
+
+
+ def parse_indices(idxs):
+     """Parses a string of indices into a list of integers. e.g. '0,1,2-4' -> [0, 1, 2, 3, 4]"""
+     if isinstance(idxs, str):
+         indices = []
+         for idx in idxs.split(","):
+             if "-" in idx:
+                 start, end = map(int, idx.split("-"))
+                 indices.extend(range(start, end + 1))
+             else:
+                 indices.append(int(idx))
+         return indices
+     return idxs
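
Putting the helpers above together, a sketch of the CSV-to-dataset path; the file name, column names, and targets are hypothetical:

    smiss, rxnss, Y, weights, lt_mask, gt_mask = parse_csv(
        "toy.csv", smiles_cols=["smiles"], rxn_cols=None, target_cols=["logp"],
        ignore_cols=None, splits_col=None, weight_col=None,
    )
    mol_data, rxn_data = make_datapoints(
        smiss, rxnss, Y, weights, lt_mask, gt_mask,
        X_d=None, V_fss=None, E_fss=None, V_dss=None,
        molecule_featurizers=None, keep_h=False, add_h=False, ignore_chirality=False,
    )
    dset = make_dataset(mol_data[0], reaction_mode="REAC_DIFF")  # a MoleculeDataset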
chemprop-updated/chemprop/cli/utils/utils.py ADDED
@@ -0,0 +1,31 @@
+ from typing import Any
+
+ __all__ = ["pop_attr"]
+
+
+ def pop_attr(o: object, attr: str, *args) -> Any | None:
+     """like ``pop()`` but for attribute maps"""
+     match len(args):
+         case 0:
+             return _pop_attr(o, attr)
+         case 1:
+             return _pop_attr_d(o, attr, args[0])
+         case _:
+             raise TypeError(f"Expected at most 2 arguments! got: {len(args)}")
+
+
+ def _pop_attr(o: object, attr: str) -> Any:
+     val = getattr(o, attr)
+     delattr(o, attr)
+
+     return val
+
+
+ def _pop_attr_d(o: object, attr: str, default: Any | None = None) -> Any | None:
+     try:
+         val = getattr(o, attr)
+         delattr(o, attr)
+     except AttributeError:
+         val = default
+
+     return val
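
`pop_attr` mirrors `dict.pop` semantics for object attributes; a small sketch:

    from types import SimpleNamespace

    ns = SimpleNamespace(a=1)
    pop_attr(ns, "a")        # -> 1, and `ns` no longer has attribute `a`
    pop_attr(ns, "a", None)  # -> None (the default), instead of raising AttributeError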
chemprop-updated/chemprop/conf.py ADDED
@@ -0,0 +1,6 @@
+ """Global configuration variables for chemprop"""
+
+ from chemprop.featurizers.molgraph.molecule import SimpleMoleculeMolGraphFeaturizer
+
+ DEFAULT_ATOM_FDIM, DEFAULT_BOND_FDIM = SimpleMoleculeMolGraphFeaturizer().shape
+ DEFAULT_HIDDEN_DIM = 300
chemprop-updated/chemprop/data/__init__.py ADDED
@@ -0,0 +1,41 @@
+ from .collate import (
+     BatchMolGraph,
+     MulticomponentTrainingBatch,
+     TrainingBatch,
+     collate_batch,
+     collate_multicomponent,
+ )
+ from .dataloader import build_dataloader
+ from .datapoints import MoleculeDatapoint, ReactionDatapoint
+ from .datasets import (
+     Datum,
+     MoleculeDataset,
+     MolGraphDataset,
+     MulticomponentDataset,
+     ReactionDataset,
+ )
+ from .molgraph import MolGraph
+ from .samplers import ClassBalanceSampler, SeededSampler
+ from .splitting import SplitType, make_split_indices, split_data_by_indices
+
+ __all__ = [
+     "BatchMolGraph",
+     "TrainingBatch",
+     "collate_batch",
+     "MulticomponentTrainingBatch",
+     "collate_multicomponent",
+     "build_dataloader",
+     "MoleculeDatapoint",
+     "ReactionDatapoint",
+     "MoleculeDataset",
+     "ReactionDataset",
+     "Datum",
+     "MulticomponentDataset",
+     "MolGraphDataset",
+     "MolGraph",
+     "ClassBalanceSampler",
+     "SeededSampler",
+     "SplitType",
+     "make_split_indices",
+     "split_data_by_indices",
+ ]
chemprop-updated/chemprop/data/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (1.43 kB)
chemprop-updated/chemprop/data/__pycache__/data.cpython-37.pyc ADDED
Binary file (43.5 kB)
chemprop-updated/chemprop/data/__pycache__/scaffold.cpython-37.pyc ADDED
Binary file (7.37 kB)
chemprop-updated/chemprop/data/__pycache__/scaler.cpython-37.pyc ADDED
Binary file (5.95 kB)
chemprop-updated/chemprop/data/__pycache__/utils.cpython-37.pyc ADDED
Binary file (35.1 kB)
chemprop-updated/chemprop/data/collate.py ADDED
@@ -0,0 +1,123 @@
+ from dataclasses import InitVar, dataclass, field
+ from typing import Iterable, NamedTuple, Sequence
+
+ import numpy as np
+ import torch
+ from torch import Tensor
+
+ from chemprop.data.datasets import Datum
+ from chemprop.data.molgraph import MolGraph
+
+
+ @dataclass(repr=False, eq=False, slots=True)
+ class BatchMolGraph:
+     """A :class:`BatchMolGraph` represents a batch of individual :class:`MolGraph`\s.
+
+     It has all the attributes of a ``MolGraph`` with the addition of the ``batch`` attribute. This
+     class is intended for use with data loading, so it uses :obj:`~torch.Tensor`\s to store data
+     """
+
+     mgs: InitVar[Sequence[MolGraph]]
+     """A list of individual :class:`MolGraph`\s to be batched together"""
+     V: Tensor = field(init=False)
+     """the atom feature matrix"""
+     E: Tensor = field(init=False)
+     """the bond feature matrix"""
+     edge_index: Tensor = field(init=False)
+     """a tensor of shape ``2 x E`` containing the edges of the graph in COO format"""
+     rev_edge_index: Tensor = field(init=False)
+     """A tensor of shape ``E`` that maps from an edge index to the index of the source of the
+     reverse edge in the ``edge_index`` attribute."""
+     batch: Tensor = field(init=False)
+     """the index of the parent :class:`MolGraph` in the batched graph"""
+     names: list[str] = field(init=False)  # SMILES strings for the batch
+
+     __size: int = field(init=False)
+
+     def __post_init__(self, mgs: Sequence[MolGraph]):
+         self.__size = len(mgs)
+
+         Vs = []
+         Es = []
+         edge_indexes = []
+         rev_edge_indexes = []
+         batch_indexes = []
+         self.names = []
+
+         num_nodes = 0
+         num_edges = 0
+         for i, mg in enumerate(mgs):
+             Vs.append(mg.V)
+             Es.append(mg.E)
+             edge_indexes.append(mg.edge_index + num_nodes)
+             rev_edge_indexes.append(mg.rev_edge_index + num_edges)
+             batch_indexes.append([i] * len(mg.V))
+             self.names.append(mg.name)
+
+             num_nodes += mg.V.shape[0]
+             num_edges += mg.edge_index.shape[1]
+
+         self.V = torch.from_numpy(np.concatenate(Vs)).float()
+         self.E = torch.from_numpy(np.concatenate(Es)).float()
+         self.edge_index = torch.from_numpy(np.hstack(edge_indexes)).long()
+         self.rev_edge_index = torch.from_numpy(np.concatenate(rev_edge_indexes)).long()
+         self.batch = torch.tensor(np.concatenate(batch_indexes)).long()
+
+     def __len__(self) -> int:
+         """the number of individual :class:`MolGraph`\s in this batch"""
+         return self.__size
+
+     def to(self, device: str | torch.device):
+         self.V = self.V.to(device)
+         self.E = self.E.to(device)
+         self.edge_index = self.edge_index.to(device)
+         self.rev_edge_index = self.rev_edge_index.to(device)
+         self.batch = self.batch.to(device)
+
+
+ class TrainingBatch(NamedTuple):
+     bmg: BatchMolGraph
+     V_d: Tensor | None
+     X_d: Tensor | None
+     Y: Tensor | None
+     w: Tensor
+     lt_mask: Tensor | None
+     gt_mask: Tensor | None
+
+
+ def collate_batch(batch: Iterable[Datum]) -> TrainingBatch:
+     mgs, V_ds, x_ds, ys, weights, lt_masks, gt_masks = zip(*batch)
+
+     return TrainingBatch(
+         BatchMolGraph(mgs),
+         None if V_ds[0] is None else torch.from_numpy(np.concatenate(V_ds)).float(),
+         None if x_ds[0] is None else torch.from_numpy(np.array(x_ds)).float(),
+         None if ys[0] is None else torch.from_numpy(np.array(ys)).float(),
+         torch.tensor(weights, dtype=torch.float).unsqueeze(1),
+         None if lt_masks[0] is None else torch.from_numpy(np.array(lt_masks)),
+         None if gt_masks[0] is None else torch.from_numpy(np.array(gt_masks)),
+     )
+
+
+ class MulticomponentTrainingBatch(NamedTuple):
+     bmgs: list[BatchMolGraph]
+     V_ds: list[Tensor | None]
+     X_d: Tensor | None
+     Y: Tensor | None
+     w: Tensor
+     lt_mask: Tensor | None
+     gt_mask: Tensor | None
+
+
+ def collate_multicomponent(batches: Iterable[Iterable[Datum]]) -> MulticomponentTrainingBatch:
+     tbs = [collate_batch(batch) for batch in zip(*batches)]
+
+     return MulticomponentTrainingBatch(
+         [tb.bmg for tb in tbs],
+         [tb.V_d for tb in tbs],
+         tbs[0].X_d,
+         tbs[0].Y,
+         tbs[0].w,
+         tbs[0].lt_mask,
+         tbs[0].gt_mask,
+     )
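
The key detail in `BatchMolGraph.__post_init__` is that node indices are offset before concatenation so each graph's edges keep pointing at its own atoms. A standalone numpy sketch of that offsetting:

    import numpy as np

    edge_index_1 = np.array([[0, 1], [1, 0]])  # graph 1: 2 atoms, edges 0->1 and 1->0
    edge_index_2 = np.array([[0, 1], [1, 0]])  # graph 2: 2 atoms
    batched = np.hstack([edge_index_1, edge_index_2 + 2])  # graph 2's atoms become 2 and 3
    print(batched)
    # [[0 1 2 3]
    #  [1 0 3 2]]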
chemprop-updated/chemprop/data/dataloader.py ADDED
@@ -0,0 +1,71 @@
+ import logging
+
+ from torch.utils.data import DataLoader
+
+ from chemprop.data.collate import collate_batch, collate_multicomponent
+ from chemprop.data.datasets import MoleculeDataset, MulticomponentDataset, ReactionDataset
+ from chemprop.data.samplers import ClassBalanceSampler, SeededSampler
+
+ logger = logging.getLogger(__name__)
+
+
+ def build_dataloader(
+     dataset: MoleculeDataset | ReactionDataset | MulticomponentDataset,
+     batch_size: int = 64,
+     num_workers: int = 0,
+     class_balance: bool = False,
+     seed: int | None = None,
+     shuffle: bool = True,
+     **kwargs,
+ ):
+     """Return a :obj:`~torch.utils.data.DataLoader` for :class:`MolGraphDataset`\s
+
+     Parameters
+     ----------
+     dataset : MoleculeDataset | ReactionDataset | MulticomponentDataset
+         The dataset containing the molecules or reactions to load.
+     batch_size : int, default=64
+         the batch size to load.
+     num_workers : int, default=0
+         the number of workers used to build batches.
+     class_balance : bool, default=False
+         Whether to perform class balancing (i.e., use an equal number of positive and negative
+         molecules). Class balance is only available for single-task classification datasets. Set
+         shuffle to True in order to get a random subset of the larger class.
+     seed : int, default=None
+         the random seed to use for shuffling (only used when ``shuffle`` is ``True``).
+     shuffle : bool, default=True
+         whether to shuffle the data during sampling.
+     """
+
+     if class_balance:
+         sampler = ClassBalanceSampler(dataset.Y, seed, shuffle)
+     elif shuffle and seed is not None:
+         sampler = SeededSampler(len(dataset), seed)
+     else:
+         sampler = None
+
+     if isinstance(dataset, MulticomponentDataset):
+         collate_fn = collate_multicomponent
+     else:
+         collate_fn = collate_batch
+
+     if len(dataset) % batch_size == 1:
+         logger.warning(
+             f"Dropping last batch of size 1 to avoid issues with batch normalization "
+             f"(dataset size = {len(dataset)}, batch_size = {batch_size})"
+         )
+         drop_last = True
+     else:
+         drop_last = False
+
+     return DataLoader(
+         dataset,
+         batch_size,
+         sampler is None and shuffle,
+         sampler,
+         num_workers=num_workers,
+         collate_fn=collate_fn,
+         drop_last=drop_last,
+         **kwargs,
+     )
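
Typical usage, assuming `train_dset` and `val_dset` are already-built datasets (hypothetical here):

    train_loader = build_dataloader(train_dset, batch_size=50, num_workers=4, seed=0)
    val_loader = build_dataloader(val_dset, shuffle=False)
    for bmg, V_d, X_d, Y, w, lt_mask, gt_mask in train_loader:
        ...  # each iteration yields one TrainingBatch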
chemprop-updated/chemprop/data/datapoints.py ADDED
@@ -0,0 +1,150 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ import numpy as np
+ from rdkit.Chem import AllChem as Chem
+
+ from chemprop.featurizers import Featurizer
+ from chemprop.utils import make_mol
+
+ MoleculeFeaturizer = Featurizer[Chem.Mol, np.ndarray]
+
+
+ @dataclass(slots=True)
+ class _DatapointMixin:
+     """A mixin class for molecule-, reaction-, and multicomponent-type data"""
+
+     y: np.ndarray | None = None
+     """the targets for the molecule with unknown targets indicated by `nan`s"""
+     weight: float = 1.0
+     """the weight of this datapoint for the loss calculation."""
+     gt_mask: np.ndarray | None = None
+     """Indicates whether the targets are an inequality regression target of the form `>x`"""
+     lt_mask: np.ndarray | None = None
+     """Indicates whether the targets are an inequality regression target of the form `<x`"""
+     x_d: np.ndarray | None = None
+     """A vector of length ``d_f`` containing additional features (e.g., Morgan fingerprint) that
+     will be concatenated to the global representation *after* aggregation"""
+     x_phase: list[float] | None = None
+     """A one-hot vector indicating the phase of the data, as used in spectra data."""
+     name: str | None = None
+     """A string identifier for the datapoint."""
+
+     def __post_init__(self):
+         NAN_TOKEN = 0
+         if self.x_d is not None:
+             self.x_d[np.isnan(self.x_d)] = NAN_TOKEN
+
+     @property
+     def t(self) -> int | None:
+         return len(self.y) if self.y is not None else None
+
+
+ @dataclass
+ class _MoleculeDatapointMixin:
+     mol: Chem.Mol
+     """the molecule associated with this datapoint"""
+
+     @classmethod
+     def from_smi(
+         cls,
+         smi: str,
+         *args,
+         keep_h: bool = False,
+         add_h: bool = False,
+         ignore_chirality: bool = False,
+         **kwargs,
+     ) -> _MoleculeDatapointMixin:
+         mol = make_mol(smi, keep_h, add_h, ignore_chirality)
+
+         kwargs["name"] = smi if "name" not in kwargs else kwargs["name"]
+
+         return cls(mol, *args, **kwargs)
+
+
+ @dataclass
+ class MoleculeDatapoint(_DatapointMixin, _MoleculeDatapointMixin):
+     """A :class:`MoleculeDatapoint` contains a single molecule and its associated features and targets."""
+
+     V_f: np.ndarray | None = None
+     """a numpy array of shape ``V x d_vf``, where ``V`` is the number of atoms in the molecule,
+     and ``d_vf`` is the number of additional features that will be concatenated to atom-level
+     features *before* message passing"""
+     E_f: np.ndarray | None = None
+     """A numpy array of shape ``E x d_ef``, where ``E`` is the number of bonds in the molecule,
+     and ``d_ef`` is the number of additional features that will be concatenated to bond-level
+     features *before* message passing"""
+     V_d: np.ndarray | None = None
+     """A numpy array of shape ``V x d_vd``, where ``V`` is the number of atoms in the molecule,
+     and ``d_vd`` is the number of additional descriptors that will be concatenated to atom-level
+     descriptors *after* message passing"""
+
+     def __post_init__(self):
+         NAN_TOKEN = 0
+         if self.V_f is not None:
+             self.V_f[np.isnan(self.V_f)] = NAN_TOKEN
+         if self.E_f is not None:
+             self.E_f[np.isnan(self.E_f)] = NAN_TOKEN
+         if self.V_d is not None:
+             self.V_d[np.isnan(self.V_d)] = NAN_TOKEN
+
+         super().__post_init__()
+
+     def __len__(self) -> int:
+         return 1
+
+
+ @dataclass
+ class _ReactionDatapointMixin:
+     rct: Chem.Mol
+     """the reactant associated with this datapoint"""
+     pdt: Chem.Mol
+     """the product associated with this datapoint"""
+
+     @classmethod
+     def from_smi(
+         cls,
+         rxn_or_smis: str | tuple[str, str],
+         *args,
+         keep_h: bool = False,
+         add_h: bool = False,
+         ignore_chirality: bool = False,
+         **kwargs,
+     ) -> _ReactionDatapointMixin:
+         match rxn_or_smis:
+             case str():
+                 rct_smi, agt_smi, pdt_smi = rxn_or_smis.split(">")
+                 rct_smi = f"{rct_smi}.{agt_smi}" if agt_smi else rct_smi
+                 name = rxn_or_smis
+             case tuple():
+                 rct_smi, pdt_smi = rxn_or_smis
+                 name = ">>".join(rxn_or_smis)
+             case _:
+                 raise TypeError(
+                     "Must provide either a reaction SMARTS string or a tuple of reactant and"
+                     " product SMILES strings!"
+                 )
+
+         rct = make_mol(rct_smi, keep_h, add_h, ignore_chirality)
+         pdt = make_mol(pdt_smi, keep_h, add_h, ignore_chirality)
+
+         kwargs["name"] = name if "name" not in kwargs else kwargs["name"]
+
+         return cls(rct, pdt, *args, **kwargs)
+
+
+ @dataclass
+ class ReactionDatapoint(_DatapointMixin, _ReactionDatapointMixin):
+     """A :class:`ReactionDatapoint` contains a single reaction and its associated features and targets."""
+
+     def __post_init__(self):
+         if self.rct is None:
+             raise ValueError("Reactant cannot be `None`!")
+         if self.pdt is None:
+             raise ValueError("Product cannot be `None`!")
+
+         return super().__post_init__()
+
+     def __len__(self) -> int:
+         return 2
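
A sketch of constructing datapoints directly from SMILES; the target values are made up:

    import numpy as np

    mol_dp = MoleculeDatapoint.from_smi("CCO", y=np.array([0.5]))
    rxn_dp = ReactionDatapoint.from_smi("CCO>>CC=O", y=np.array([1.0]))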
chemprop-updated/chemprop/data/datasets.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from functools import cached_property
3
+ from typing import NamedTuple, TypeAlias
4
+
5
+ import numpy as np
6
+ from numpy.typing import ArrayLike
7
+ from rdkit import Chem
8
+ from rdkit.Chem import Mol
9
+ from sklearn.preprocessing import StandardScaler
10
+ from torch.utils.data import Dataset
11
+
12
+ from chemprop.data.datapoints import MoleculeDatapoint, ReactionDatapoint
13
+ from chemprop.data.molgraph import MolGraph
14
+ from chemprop.featurizers.base import Featurizer
15
+ from chemprop.featurizers.molgraph import CGRFeaturizer, SimpleMoleculeMolGraphFeaturizer
16
+ from chemprop.featurizers.molgraph.cache import MolGraphCache, MolGraphCacheOnTheFly
17
+ from chemprop.types import Rxn
18
+
19
+
20
+ class Datum(NamedTuple):
21
+ """a singular training data point"""
22
+
23
+ mg: MolGraph
24
+ V_d: np.ndarray | None
25
+ x_d: np.ndarray | None
26
+ y: np.ndarray | None
27
+ weight: float
28
+ lt_mask: np.ndarray | None
29
+ gt_mask: np.ndarray | None
30
+
31
+
32
+ MolGraphDataset: TypeAlias = Dataset[Datum]
33
+
34
+
35
+ class _MolGraphDatasetMixin:
36
+ def __len__(self) -> int:
37
+ return len(self.data)
38
+
39
+ @cached_property
40
+ def _Y(self) -> np.ndarray:
41
+ """the raw targets of the dataset"""
42
+ return np.array([d.y for d in self.data], float)
43
+
44
+ @property
45
+ def Y(self) -> np.ndarray:
46
+ """the (scaled) targets of the dataset"""
47
+ return self.__Y
48
+
49
+ @Y.setter
50
+ def Y(self, Y: ArrayLike):
51
+ self._validate_attribute(Y, "targets")
52
+
53
+ self.__Y = np.array(Y, float)
54
+
55
+ @cached_property
56
+ def _X_d(self) -> np.ndarray:
57
+ """the raw extra descriptors of the dataset"""
58
+ return np.array([d.x_d for d in self.data])
59
+
60
+ @property
61
+ def X_d(self) -> np.ndarray:
62
+ """the (scaled) extra descriptors of the dataset"""
63
+ return self.__X_d
64
+
65
+ @X_d.setter
66
+ def X_d(self, X_d: ArrayLike):
67
+ self._validate_attribute(X_d, "extra descriptors")
68
+
69
+ self.__X_d = np.array(X_d)
70
+
71
+ @property
72
+ def weights(self) -> np.ndarray:
73
+ return np.array([d.weight for d in self.data])
74
+
75
+ @property
76
+ def gt_mask(self) -> np.ndarray:
77
+ return np.array([d.gt_mask for d in self.data])
78
+
79
+ @property
80
+ def lt_mask(self) -> np.ndarray:
81
+ return np.array([d.lt_mask for d in self.data])
82
+
83
+ @property
84
+ def t(self) -> int | None:
85
+ return self.data[0].t if len(self.data) > 0 else None
86
+
87
+ @property
88
+ def d_xd(self) -> int:
89
+ """the extra molecule descriptor dimension, if any"""
90
+ return 0 if self.X_d[0] is None else self.X_d.shape[1]
91
+
92
+ @property
93
+ def names(self) -> list[str]:
94
+ return [d.name for d in self.data]
95
+
96
+ def normalize_targets(self, scaler: StandardScaler | None = None) -> StandardScaler:
97
+ """Normalizes the targets of this dataset using a :obj:`StandardScaler`
98
+
99
+ The :obj:`StandardScaler` subtracts the mean and divides by the standard deviation for
100
+ each task independently. NOTE: This should only be used for regression datasets.
101
+
102
+ Returns
103
+ -------
104
+ StandardScaler
105
+ a scaler fit to the targets.
106
+ """
107
+
108
+ if scaler is None:
109
+ scaler = StandardScaler().fit(self._Y)
110
+
111
+ self.Y = scaler.transform(self._Y)
112
+
113
+ return scaler
114
+
115
+ def normalize_inputs(
116
+ self, key: str = "X_d", scaler: StandardScaler | None = None
117
+ ) -> StandardScaler:
118
+ VALID_KEYS = {"X_d"}
119
+ if key not in VALID_KEYS:
120
+ raise ValueError(f"Invalid feature key! got: {key}. expected one of: {VALID_KEYS}")
121
+
122
+ X = self.X_d if self.X_d[0] is not None else None
123
+
124
+ if X is None:
125
+ return scaler
126
+
127
+ if scaler is None:
128
+ scaler = StandardScaler().fit(X)
129
+
130
+ self.X_d = scaler.transform(X)
131
+
132
+ return scaler
133
+
134
+ def reset(self):
135
+ """Reset the atom and bond features; atom and extra descriptors; and targets of each
136
+ datapoint to their initial, unnormalized values."""
137
+ self.__Y = self._Y
138
+ self.__X_d = self._X_d
139
+
140
+ def _validate_attribute(self, X: np.ndarray, label: str):
141
+ if not len(self.data) == len(X):
142
+ raise ValueError(
143
+ f"number of molecules ({len(self.data)}) and {label} ({len(X)}) "
144
+ "must have same length!"
145
+ )
146
+
147
+
148
+ @dataclass
149
+ class MoleculeDataset(_MolGraphDatasetMixin, MolGraphDataset):
150
+ """A :class:`MoleculeDataset` composed of :class:`MoleculeDatapoint`\s
151
+
152
+ A :class:`MoleculeDataset` produces featurized data for input to a
153
+ :class:`MPNN` model. Typically, data featurization is performed on-the-fly
154
+ and parallelized across multiple workers via the :class:`~torch.utils.data
155
+ DataLoader` class. However, for small datasets, it may be more efficient to
156
+ featurize the data in advance and cache the results. This can be done by
157
+ setting ``MoleculeDataset.cache=True``.
158
+
159
+ Parameters
160
+ ----------
161
+ data : Iterable[MoleculeDatapoint]
162
+ the data from which to create a dataset
163
+ featurizer : MoleculeFeaturizer
164
+ the featurizer with which to generate MolGraphs of the molecules
165
+ """
166
+
167
+ data: list[MoleculeDatapoint]
168
+ featurizer: Featurizer[Mol, MolGraph] = field(default_factory=SimpleMoleculeMolGraphFeaturizer)
169
+
170
+ def __post_init__(self):
171
+ if self.data is None:
172
+ raise ValueError("Data cannot be None!")
173
+
174
+ self.reset()
175
+ self.cache = False
176
+
177
+ def __getitem__(self, idx: int) -> Datum:
178
+ d = self.data[idx]
179
+ mg = self.mg_cache[idx]
180
+
181
+ # Assign the SMILES string to the MolGraph
182
+ mg_with_name = MolGraph(
183
+ V=mg.V,
184
+ E=mg.E,
185
+ edge_index=mg.edge_index,
186
+ rev_edge_index=mg.rev_edge_index,
187
+ name=d.name # Assign the SMILES string
188
+ )
189
+
190
+ return Datum(
191
+ mg=mg_with_name, # Use the updated MolGraph
192
+ V_d=self.V_ds[idx],
193
+ x_d=self.X_d[idx],
194
+ y=self.Y[idx],
195
+ weight=d.weight,
196
+ lt_mask=d.lt_mask,
197
+ gt_mask=d.gt_mask,
198
+ )
199
+ @property
200
+ def cache(self) -> bool:
201
+ return self.__cache
202
+
203
+ @cache.setter
204
+ def cache(self, cache: bool = False):
205
+ self.__cache = cache
206
+ self._init_cache()
207
+
208
+ def _init_cache(self):
209
+ """initialize the cache"""
210
+ self.mg_cache = (MolGraphCache if self.cache else MolGraphCacheOnTheFly)(
211
+ self.mols, self.V_fs, self.E_fs, self.featurizer
212
+ )
213
+
214
+ @property
215
+ def smiles(self) -> list[str]:
216
+ """the SMILES strings associated with the dataset"""
217
+ return [Chem.MolToSmiles(d.mol) for d in self.data]
218
+
219
+ @property
220
+ def mols(self) -> list[Chem.Mol]:
221
+ """the molecules associated with the dataset"""
222
+ return [d.mol for d in self.data]
223
+
224
+ @property
225
+ def _V_fs(self) -> list[np.ndarray]:
226
+ """the raw atom features of the dataset"""
227
+ return [d.V_f for d in self.data]
228
+
229
+ @property
230
+ def V_fs(self) -> list[np.ndarray]:
231
+ """the (scaled) atom descriptors of the dataset"""
232
+ return self.__V_fs
233
+
234
+ @V_fs.setter
235
+ def V_fs(self, V_fs: list[np.ndarray]):
236
+ """the (scaled) atom features of the dataset"""
237
+ self._validate_attribute(V_fs, "atom features")
238
+
239
+ self.__V_fs = V_fs
240
+ self._init_cache()
241
+
242
+ @property
243
+ def _E_fs(self) -> list[np.ndarray]:
244
+ """the raw bond features of the dataset"""
245
+ return [d.E_f for d in self.data]
246
+
247
+ @property
248
+ def E_fs(self) -> list[np.ndarray]:
249
+ """the (scaled) bond features of the dataset"""
250
+ return self.__E_fs
251
+
252
+ @E_fs.setter
253
+ def E_fs(self, E_fs: list[np.ndarray]):
254
+ self._validate_attribute(E_fs, "bond features")
255
+
256
+ self.__E_fs = E_fs
257
+ self._init_cache()
258
+
259
+ @property
260
+ def _V_ds(self) -> list[np.ndarray]:
261
+ """the raw atom descriptors of the dataset"""
262
+ return [d.V_d for d in self.data]
263
+
264
+ @property
265
+ def V_ds(self) -> list[np.ndarray]:
266
+ """the (scaled) atom descriptors of the dataset"""
267
+ return self.__V_ds
268
+
269
+ @V_ds.setter
270
+ def V_ds(self, V_ds: list[np.ndarray]):
271
+ self._validate_attribute(V_ds, "atom descriptors")
272
+
273
+ self.__V_ds = V_ds
274
+
275
+ @property
276
+ def d_vf(self) -> int:
277
+ """the extra atom feature dimension, if any"""
278
+ return 0 if self.V_fs[0] is None else self.V_fs[0].shape[1]
279
+
280
+ @property
281
+ def d_ef(self) -> int:
282
+ """the extra bond feature dimension, if any"""
283
+ return 0 if self.E_fs[0] is None else self.E_fs[0].shape[1]
284
+
285
+ @property
286
+ def d_vd(self) -> int:
287
+ """the extra atom descriptor dimension, if any"""
288
+ return 0 if self.V_ds[0] is None else self.V_ds[0].shape[1]
289
+
290
+ def normalize_inputs(
291
+ self, key: str = "X_d", scaler: StandardScaler | None = None
292
+ ) -> StandardScaler:
293
+ VALID_KEYS = {"X_d", "V_f", "E_f", "V_d"}
294
+
295
+ match key:
296
+ case "X_d":
297
+ X = None if self.d_xd == 0 else self.X_d
298
+ case "V_f":
299
+ X = None if self.d_vf == 0 else np.concatenate(self.V_fs, axis=0)
300
+ case "E_f":
301
+ X = None if self.d_ef == 0 else np.concatenate(self.E_fs, axis=0)
302
+ case "V_d":
303
+ X = None if self.d_vd == 0 else np.concatenate(self.V_ds, axis=0)
304
+ case _:
305
+ raise ValueError(f"Invalid feature key! got: {key}. expected one of: {VALID_KEYS}")
306
+
307
+ if X is None:
308
+ return scaler
309
+
310
+ if scaler is None:
311
+ scaler = StandardScaler().fit(X)
312
+
313
+ match key:
314
+ case "X_d":
315
+ self.X_d = scaler.transform(X)
316
+ case "V_f":
317
+ self.V_fs = [scaler.transform(V_f) if V_f.size > 0 else V_f for V_f in self.V_fs]
318
+ case "E_f":
319
+ self.E_fs = [scaler.transform(E_f) if E_f.size > 0 else E_f for E_f in self.E_fs]
320
+ case "V_d":
321
+ self.V_ds = [scaler.transform(V_d) if V_d.size > 0 else V_d for V_d in self.V_ds]
322
+ case _:
323
+ raise RuntimeError("unreachable code reached!")
324
+
325
+ return scaler
326
+
327
+ def reset(self):
328
+ """Reset the atom and bond features; atom and extra descriptors; and targets of each
329
+ datapoint to their initial, unnormalized values."""
330
+ super().reset()
331
+ self.__V_fs = self._V_fs
332
+ self.__E_fs = self._E_fs
333
+ self.__V_ds = self._V_ds
334
+
335
+
336
+ @dataclass
337
+ class ReactionDataset(_MolGraphDatasetMixin, MolGraphDataset):
338
+ """A :class:`ReactionDataset` composed of :class:`ReactionDatapoint`\s
339
+
340
+ .. note::
341
+ The featurized data provided by this class may be cached, simlar to a
342
+ :class:`MoleculeDataset`. To enable the cache, set ``ReactionDataset
343
+ cache=True``.
344
+ """
345
+
+     data: list[ReactionDatapoint]
+     """the dataset from which to load"""
+     featurizer: Featurizer[Rxn, MolGraph] = field(default_factory=CGRFeaturizer)
+     """the featurizer with which to generate MolGraphs of the input"""
+
+     def __post_init__(self):
+         if self.data is None:
+             raise ValueError("Data cannot be None!")
+
+         self.reset()
+         self.cache = False
+
+     @property
+     def cache(self) -> bool:
+         return self.__cache
+
+     @cache.setter
+     def cache(self, cache: bool = False):
+         self.__cache = cache
+         self.mg_cache = (MolGraphCache if cache else MolGraphCacheOnTheFly)(
+             self.mols, [None] * len(self), [None] * len(self), self.featurizer
+         )
+
+     def __getitem__(self, idx: int) -> Datum:
+         d = self.data[idx]
+         mg = self.mg_cache[idx]
+
+         return Datum(mg, None, self.X_d[idx], self.Y[idx], d.weight, d.lt_mask, d.gt_mask)
+
+     @property
+     def smiles(self) -> list[tuple]:
+         return [(Chem.MolToSmiles(d.rct), Chem.MolToSmiles(d.pdt)) for d in self.data]
+
+     @property
+     def mols(self) -> list[Rxn]:
+         return [(d.rct, d.pdt) for d in self.data]
+
+     @property
+     def d_vf(self) -> int:
+         return 0
+
+     @property
+     def d_ef(self) -> int:
+         return 0
+
+     @property
+     def d_vd(self) -> int:
+         return 0
+
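A short sketch of the caching behavior described in the note above (``rxn_points`` is a hypothetical list of :class:`ReactionDatapoint` objects):

    dset = ReactionDataset(rxn_points)
    dset.cache = True   # eagerly featurize and store every MolGraph (MolGraphCache)
    dset.cache = False  # featurize lazily on each access (MolGraphCacheOnTheFly)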
+
+ @dataclass(repr=False, eq=False)
+ class MulticomponentDataset(_MolGraphDatasetMixin, Dataset):
+     r"""A :class:`MulticomponentDataset` is a :class:`Dataset` composed of parallel
+     :class:`MoleculeDataset`\s and :class:`ReactionDataset`\s"""
+
+     datasets: list[MoleculeDataset | ReactionDataset]
+     """the parallel datasets"""
+
+     def __post_init__(self):
+         sizes = [len(dset) for dset in self.datasets]
+         if not all(sizes[0] == size for size in sizes[1:]):
+             raise ValueError(f"Datasets must all have the same length! got: {sizes}")
+
+     def __len__(self) -> int:
+         return len(self.datasets[0])
+
+     @property
+     def n_components(self) -> int:
+         return len(self.datasets)
+
+     def __getitem__(self, idx: int) -> list[Datum]:
+         return [dset[idx] for dset in self.datasets]
+
+     @property
+     def smiles(self) -> list[list[str]]:
+         return list(zip(*[dset.smiles for dset in self.datasets]))
+
+     @property
+     def names(self) -> list[list[str]]:
+         return list(zip(*[dset.names for dset in self.datasets]))
+
+     @property
+     def mols(self) -> list[list[Chem.Mol]]:
+         return list(zip(*[dset.mols for dset in self.datasets]))
+
+     def normalize_targets(self, scaler: StandardScaler | None = None) -> StandardScaler:
+         return self.datasets[0].normalize_targets(scaler)
+
+     def normalize_inputs(
+         self, key: str = "X_d", scaler: list[StandardScaler] | None = None
+     ) -> list[StandardScaler]:
+         RXN_VALID_KEYS = {"X_d"}
+         match scaler:
+             case None:
+                 return [
+                     dset.normalize_inputs(key)
+                     if isinstance(dset, MoleculeDataset) or key in RXN_VALID_KEYS
+                     else None
+                     for dset in self.datasets
+                 ]
+             case _:
+                 assert len(scaler) == len(
+                     self.datasets
+                 ), "Number of scalers must match number of datasets!"
+
+                 return [
+                     dset.normalize_inputs(key, s)
+                     if isinstance(dset, MoleculeDataset) or key in RXN_VALID_KEYS
+                     else None
+                     for dset, s in zip(self.datasets, scaler)
+                 ]
+
+     def reset(self):
+         return [dset.reset() for dset in self.datasets]
+
+     @property
+     def d_xd(self) -> int:
+         return self.datasets[0].d_xd
+
+     @property
+     def d_vf(self) -> int:
+         return sum(dset.d_vf for dset in self.datasets)
+
+     @property
+     def d_ef(self) -> int:
+         return sum(dset.d_ef for dset in self.datasets)
+
+     @property
+     def d_vd(self) -> int:
+         return sum(dset.d_vd for dset in self.datasets)
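A minimal construction sketch (``solute_points`` and ``solvent_points`` are hypothetical, equal-length datapoint lists):

    dset = MulticomponentDataset(
        [MoleculeDataset(solute_points), MoleculeDataset(solvent_points)]
    )
    assert dset.n_components == 2
    datum_pair = dset[0]  # one Datum per component dataset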
chemprop-updated/chemprop/data/molgraph.py ADDED
@@ -0,0 +1,17 @@
+ from typing import NamedTuple
+
+ import numpy as np
+
+
+ class MolGraph(NamedTuple):
+     """A :class:`MolGraph` represents the graph featurization of a molecule."""
+
+     V: np.ndarray
+     """an array of shape ``V x d_v`` containing the atom features of the molecule"""
+     E: np.ndarray
+     """an array of shape ``E x d_e`` containing the bond features of the molecule"""
+     edge_index: np.ndarray
+     """an array of shape ``2 x E`` containing the edges of the graph in COO format"""
+     rev_edge_index: np.ndarray
+     """an array of shape ``E`` that maps each edge to the index of its reverse edge in the :attr:`edge_index` attribute."""
+     name: str | None = None  # the SMILES string of the molecule, as an optional attribute
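The relationship between ``edge_index`` and ``rev_edge_index`` is easiest to see on a one-bond example. A minimal sketch (the feature dimensions here are hypothetical placeholders, not the featurizers' real sizes):

    import numpy as np

    # ethane's single C-C bond is stored as two directed edges, 0 -> 1 and 1 -> 0;
    # each entry of rev_edge_index holds the position of the reverse directed edge
    mg = MolGraph(
        V=np.zeros((2, 72)),   # 2 atoms x hypothetical atom-feature dim
        E=np.zeros((2, 14)),   # 2 directed edges x hypothetical bond-feature dim
        edge_index=np.array([[0, 1], [1, 0]]),
        rev_edge_index=np.array([1, 0]),
        name="CC",
    )
    assert mg.rev_edge_index[0] == 1 and mg.rev_edge_index[1] == 0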
chemprop-updated/chemprop/data/samplers.py ADDED
@@ -0,0 +1,66 @@
+ from itertools import chain
+ from typing import Iterator, Optional
+
+ import numpy as np
+ from torch.utils.data import Sampler
+
+
+ class SeededSampler(Sampler):
+     """A :class:`SeededSampler` is a class for iterating through a dataset in a randomly seeded
+     fashion"""
+
+     def __init__(self, N: int, seed: int):
+         if seed is None:
+             raise ValueError("arg 'seed' was `None`! A SeededSampler must be seeded!")
+
+         self.idxs = np.arange(N)
+         self.rg = np.random.default_rng(seed)
+
+     def __iter__(self) -> Iterator[int]:
+         """an iterator over indices to sample."""
+         self.rg.shuffle(self.idxs)
+
+         return iter(self.idxs)
+
+     def __len__(self) -> int:
+         """the number of indices that will be sampled."""
+         return len(self.idxs)
+
+
+ class ClassBalanceSampler(Sampler):
+     """A :class:`ClassBalanceSampler` samples data from a :class:`MolGraphDataset` such that
+     positive and negative classes are equally sampled
+
+     Parameters
+     ----------
+     Y : np.ndarray
+         the targets of the dataset; a datapoint is considered positive if any of its targets is nonzero
+     seed : int, optional
+         the random seed to use for shuffling (only used when ``shuffle`` is ``True``)
+     shuffle : bool, default=False
+         whether to shuffle the data during sampling
+     """
+
+     def __init__(self, Y: np.ndarray, seed: Optional[int] = None, shuffle: bool = False):
+         self.shuffle = shuffle
+         self.rg = np.random.default_rng(seed)
+
+         idxs = np.arange(len(Y))
+         actives = Y.any(1)
+
+         self.pos_idxs = idxs[actives]
+         self.neg_idxs = idxs[~actives]
+
+         self.length = 2 * min(len(self.pos_idxs), len(self.neg_idxs))
+
+     def __iter__(self) -> Iterator[int]:
+         """an iterator over indices to sample."""
+         if self.shuffle:
+             self.rg.shuffle(self.pos_idxs)
+             self.rg.shuffle(self.neg_idxs)
+
+         return chain(*zip(self.pos_idxs, self.neg_idxs))
+
+     def __len__(self) -> int:
+         """the number of indices that will be sampled."""
+         return self.length
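To make the pairing behavior above concrete, a small self-contained example: indices are yielded as alternating positive/negative pairs and truncated to twice the size of the smaller class, so part of the majority class is never sampled.

    import numpy as np

    # 3 positives (rows with any nonzero target) and 2 negatives
    # -> 2 * min(3, 2) = 4 indices, interleaved positive/negative
    Y = np.array([[1], [0], [1], [0], [1]])
    sampler = ClassBalanceSampler(Y, seed=0, shuffle=False)
    assert list(sampler) == [0, 1, 2, 3]
    assert len(sampler) == 4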
chemprop-updated/chemprop/data/splitting.py ADDED
@@ -0,0 +1,225 @@
+ from collections.abc import Iterable, Sequence
+ import copy
+ from enum import auto
+ import logging
+
+ from astartes import train_test_split, train_val_test_split
+ from astartes.molecules import train_test_split_molecules, train_val_test_split_molecules
+ import numpy as np
+ from rdkit import Chem
+
+ from chemprop.data.datapoints import MoleculeDatapoint, ReactionDatapoint
+ from chemprop.utils.utils import EnumMapping
+
+ logger = logging.getLogger(__name__)
+
+ Datapoints = Sequence[MoleculeDatapoint] | Sequence[ReactionDatapoint]
+ MulticomponentDatapoints = Sequence[Datapoints]
+
+
+ class SplitType(EnumMapping):
+     SCAFFOLD_BALANCED = auto()
+     RANDOM_WITH_REPEATED_SMILES = auto()
+     RANDOM = auto()
+     KENNARD_STONE = auto()
+     KMEANS = auto()
+
+
+ def make_split_indices(
+     mols: Sequence[Chem.Mol],
+     split: SplitType | str = "random",
+     sizes: tuple[float, float, float] = (0.8, 0.1, 0.1),
+     seed: int = 0,
+     num_replicates: int = 1,
+     num_folds: None = None,
+ ) -> tuple[list[list[int]], ...]:
+     """Splits data into training, validation, and test splits.
+
+     Parameters
+     ----------
+     mols : Sequence[Chem.Mol]
+         Sequence of RDKit molecules to use for structure-based splitting
+     split : SplitType | str, optional
+         Split type, one of :class:`SplitType`, by default "random"
+     sizes : tuple[float, float, float], optional
+         3-tuple with the proportions of data in the train, validation, and test sets, by default
+         (0.8, 0.1, 0.1). Set the middle value to 0 for a two-way split.
+     seed : int, optional
+         The random seed passed to astartes, by default 0
+     num_replicates : int, optional
+         Number of replicates, by default 1
+     num_folds : None, optional
+         This argument was removed in v2.1 - use `num_replicates` instead.
+
+     Returns
+     -------
+     tuple[list[list[int]], ...]
+         2- or 3-member tuple of ``num_replicates``-length lists of training, validation, and testing indices.
+
+         .. important::
+             Validation may or may not be present
+
+     Raises
+     ------
+     ValueError
+         Requested split sizes tuple not of length 3
+     ValueError
+         Unsupported split method requested
+     """
+     if num_folds is not None:
+         raise RuntimeError("This argument was removed in v2.1 - use `num_replicates` instead.")
+     if num_replicates == 1:
+         logger.warning(
+             "The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)"
+         )
+     if (num_splits := len(sizes)) != 3:
+         raise ValueError(
+             f"Specify sizes for train, validation, and test (got {num_splits} values)."
+         )
+     # typically include a validation set
+     include_val = True
+     split_fun = train_val_test_split
+     mol_split_fun = train_val_test_split_molecules
+     # default sampling arguments for astartes sampler
+     astartes_kwargs = dict(
+         train_size=sizes[0], test_size=sizes[2], return_indices=True, random_state=seed
+     )
+     # if no validation set, reassign the splitting functions
+     if sizes[1] == 0.0:
+         include_val = False
+         split_fun = train_test_split
+         mol_split_fun = train_test_split_molecules
+     else:
+         astartes_kwargs["val_size"] = sizes[1]
+
+     n_datapoints = len(mols)
+     train_replicates, val_replicates, test_replicates = [], [], []
+     for _ in range(num_replicates):
+         train, val, test = None, None, None
+         match SplitType.get(split):
+             case SplitType.SCAFFOLD_BALANCED:
+                 mols_without_atommaps = []
+                 for mol in mols:
+                     copied_mol = copy.deepcopy(mol)
+                     for atom in copied_mol.GetAtoms():
+                         atom.SetAtomMapNum(0)
+                     mols_without_atommaps.append(copied_mol)
+                 result = mol_split_fun(
+                     np.array(mols_without_atommaps), sampler="scaffold", **astartes_kwargs
+                 )
+                 train, val, test = _unpack_astartes_result(result, include_val)
+
+             # Used to constrain datapoints with the same SMILES to the same split.
+             case SplitType.RANDOM_WITH_REPEATED_SMILES:
+                 # get two arrays: one of all the smiles strings, one of just the unique
+                 all_smiles = np.array([Chem.MolToSmiles(mol) for mol in mols])
+                 unique_smiles = np.unique(all_smiles)
+
+                 # save a mapping of smiles -> all the indices that it appeared at
+                 smiles_indices = {}
+                 for smiles in unique_smiles:
+                     smiles_indices[smiles] = np.where(all_smiles == smiles)[0].tolist()
+
+                 # randomly split the unique smiles
+                 result = split_fun(
+                     np.arange(len(unique_smiles)), sampler="random", **astartes_kwargs
+                 )
+                 train_idxs, val_idxs, test_idxs = _unpack_astartes_result(result, include_val)
+
+                 # convert these to the 'actual' indices from the original list using the dict we made
+                 train = sum((smiles_indices[unique_smiles[i]] for i in train_idxs), [])
+                 val = sum((smiles_indices[unique_smiles[j]] for j in val_idxs), [])
+                 test = sum((smiles_indices[unique_smiles[k]] for k in test_idxs), [])
+
+             case SplitType.RANDOM:
+                 result = split_fun(np.arange(n_datapoints), sampler="random", **astartes_kwargs)
+                 train, val, test = _unpack_astartes_result(result, include_val)
+
+             case SplitType.KENNARD_STONE:
+                 result = mol_split_fun(
+                     np.array(mols),
+                     sampler="kennard_stone",
+                     hopts=dict(metric="jaccard"),
+                     fingerprint="morgan_fingerprint",
+                     fprints_hopts=dict(n_bits=2048),
+                     **astartes_kwargs,
+                 )
+                 train, val, test = _unpack_astartes_result(result, include_val)
+
+             case SplitType.KMEANS:
+                 result = mol_split_fun(
+                     np.array(mols),
+                     sampler="kmeans",
+                     hopts=dict(metric="jaccard"),
+                     fingerprint="morgan_fingerprint",
+                     fprints_hopts=dict(n_bits=2048),
+                     **astartes_kwargs,
+                 )
+                 train, val, test = _unpack_astartes_result(result, include_val)
+
+             case _:
+                 raise RuntimeError("Unreachable code reached!")
+         train_replicates.append(train)
+         val_replicates.append(val)
+         test_replicates.append(test)
+         astartes_kwargs["random_state"] += 1
+     return train_replicates, val_replicates, test_replicates
+
+
+ def _unpack_astartes_result(
+     result: tuple, include_val: bool
+ ) -> tuple[list[int], list[int], list[int]]:
+     """Helper function to partition input data based on the output of an astartes sampler
+
+     Parameters
+     ----------
+     result : tuple
+         Output from a call to astartes containing the split indices
+     include_val : bool
+         True if a validation set is included, False otherwise.
+
+     Returns
+     -------
+     train : list[int]
+     val : list[int]
+         .. important::
+             validation possibly empty
+     test : list[int]
+     """
+     train_idxs, val_idxs, test_idxs = [], [], []
+     # astartes returns a set of lists containing the data, clusters (if applicable),
+     # and indices (always last), so we pull out the indices
+     if include_val:
+         train_idxs, val_idxs, test_idxs = result[-3], result[-2], result[-1]
+     else:
+         train_idxs, test_idxs = result[-2], result[-1]
+     return list(train_idxs), list(val_idxs), list(test_idxs)
+
+
+ def split_data_by_indices(
+     data: Datapoints | MulticomponentDatapoints,
+     train_indices: Iterable[Iterable[int]] | None = None,
+     val_indices: Iterable[Iterable[int]] | None = None,
+     test_indices: Iterable[Iterable[int]] | None = None,
+ ):
+     """Splits data into training, validation, and test groups based on the split indices given."""
+
+     train_data = _splitter_helper(data, train_indices)
+     val_data = _splitter_helper(data, val_indices)
+     test_data = _splitter_helper(data, test_indices)
+
+     return train_data, val_data, test_data
+
+
+ def _splitter_helper(data, indices):
+     if indices is None:
+         return None
+
+     if isinstance(data[0], (MoleculeDatapoint, ReactionDatapoint)):
+         datapoints = data
+         idxss = indices
+         return [[datapoints[idx] for idx in idxs] for idxs in idxss]
+     else:
+         datapointss = data
+         idxss = indices
+         return [[[datapoints[idx] for idx in idxs] for datapoints in datapointss] for idxs in idxss]
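A minimal usage sketch tying the two functions together (``datapoints`` is a hypothetical list of :class:`MoleculeDatapoint` objects parallel to ``mols``):

    from rdkit import Chem

    mols = [Chem.MolFromSmiles(smi) for smi in ["CCO", "CCN", "CCC", "c1ccccc1"]]
    # one replicate of a 50/25/25 random split; each return value is a
    # num_replicates-length list of index lists
    train, val, test = make_split_indices(mols, "random", (0.5, 0.25, 0.25))
    train_data, val_data, test_data = split_data_by_indices(datapoints, train, val, test)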
chemprop-updated/chemprop/exceptions.py ADDED
@@ -0,0 +1,12 @@
+ from typing import Iterable
+
+ from chemprop.utils import pretty_shape
+
+
+ class InvalidShapeError(ValueError):
+     def __init__(self, var_name: str, received: Iterable[int], expected: Iterable[int]):
+         message = (
+             f"arg '{var_name}' has incorrect shape! "
+             f"got: `{pretty_shape(received)}`. expected: `{pretty_shape(expected)}`"
+         )
+         super().__init__(message)
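A hypothetical sketch of how this exception is raised (``V_d``, ``n_atoms``, and ``d_vd`` are placeholder names for a caller's array and its expected dimensions):

    # guard in a hypothetical validation routine
    if V_d.shape != (n_atoms, d_vd):
        raise InvalidShapeError("V_d", V_d.shape, (n_atoms, d_vd))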
chemprop-updated/chemprop/features/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (1.46 kB).
 
chemprop-updated/chemprop/features/__pycache__/features_generators.cpython-37.pyc ADDED
Binary file (5.71 kB).
 
chemprop-updated/chemprop/features/__pycache__/featurization.cpython-37.pyc ADDED
Binary file (24.7 kB).
 
chemprop-updated/chemprop/features/__pycache__/utils.cpython-37.pyc ADDED
Binary file (4.86 kB).
 
chemprop-updated/chemprop/featurizers/__init__.py ADDED
@@ -0,0 +1,51 @@
+ from .atom import AtomFeatureMode, MultiHotAtomFeaturizer, get_multi_hot_atom_featurizer
+ from .base import Featurizer, GraphFeaturizer, S, T, VectorFeaturizer
+ from .bond import MultiHotBondFeaturizer
+ from .molecule import (
+     BinaryFeaturizerMixin,
+     CountFeaturizerMixin,
+     MoleculeFeaturizerRegistry,
+     MorganBinaryFeaturizer,
+     MorganCountFeaturizer,
+     MorganFeaturizerMixin,
+     RDKit2DFeaturizer,
+     V1RDKit2DFeaturizer,
+     V1RDKit2DNormalizedFeaturizer,
+ )
+ from .molgraph import (
+     CGRFeaturizer,
+     CondensedGraphOfReactionFeaturizer,
+     MolGraphCache,
+     MolGraphCacheFacade,
+     MolGraphCacheOnTheFly,
+     RxnMode,
+     SimpleMoleculeMolGraphFeaturizer,
+ )
+
+ __all__ = [
+     "Featurizer",
+     "S",
+     "T",
+     "VectorFeaturizer",
+     "GraphFeaturizer",
+     "MultiHotAtomFeaturizer",
+     "AtomFeatureMode",
+     "get_multi_hot_atom_featurizer",
+     "MultiHotBondFeaturizer",
+     "MolGraphCacheFacade",
+     "MolGraphCache",
+     "MolGraphCacheOnTheFly",
+     "SimpleMoleculeMolGraphFeaturizer",
+     "CondensedGraphOfReactionFeaturizer",
+     "CGRFeaturizer",
+     "RxnMode",
+     "MorganFeaturizerMixin",
+     "BinaryFeaturizerMixin",
+     "CountFeaturizerMixin",
+     "MorganBinaryFeaturizer",
+     "MorganCountFeaturizer",
+     "RDKit2DFeaturizer",
+     "MoleculeFeaturizerRegistry",
+     "V1RDKit2DFeaturizer",
+     "V1RDKit2DNormalizedFeaturizer",
+ ]
chemprop-updated/chemprop/featurizers/atom.py ADDED
@@ -0,0 +1,281 @@
+ from enum import auto
+ from typing import Sequence
+
+ import numpy as np
+ from rdkit.Chem.rdchem import Atom, HybridizationType
+
+ from chemprop.featurizers.base import VectorFeaturizer
+ from chemprop.utils.utils import EnumMapping
+
+
+ class MultiHotAtomFeaturizer(VectorFeaturizer[Atom]):
+     """A :class:`MultiHotAtomFeaturizer` uses a multi-hot encoding to featurize atoms.
+
+     .. seealso::
+         The class provides three default parameterization schemes:
+
+         * :meth:`MultiHotAtomFeaturizer.v1`
+         * :meth:`MultiHotAtomFeaturizer.v2`
+         * :meth:`MultiHotAtomFeaturizer.organic`
+
+     The generated atom features are ordered as follows:
+     * atomic number
+     * degree
+     * formal charge
+     * chiral tag
+     * number of hydrogens
+     * hybridization
+     * aromaticity
+     * mass
+
+     .. important::
+         Each feature, except for aromaticity and mass, includes a pad for unknown values.
+
+     Parameters
+     ----------
+     atomic_nums : Sequence[int]
+         the choices for atom type denoted by atomic number. Ex: ``[6, 7, 8]`` for C, N, and O.
+     degrees : Sequence[int]
+         the choices for the number of bonds an atom is engaged in.
+     formal_charges : Sequence[int]
+         the choices for the integer electronic charge assigned to an atom.
+     chiral_tags : Sequence[int]
+         the choices for an atom's chiral tag. See :class:`rdkit.Chem.rdchem.ChiralType` for possible integer values.
+     num_Hs : Sequence[int]
+         the choices for the number of bonded hydrogen atoms.
+     hybridizations : Sequence[HybridizationType]
+         the choices for an atom's hybridization type. See :class:`rdkit.Chem.rdchem.HybridizationType` for possible values.
+     """
+
+     def __init__(
+         self,
+         atomic_nums: Sequence[int],
+         degrees: Sequence[int],
+         formal_charges: Sequence[int],
+         chiral_tags: Sequence[int],
+         num_Hs: Sequence[int],
+         hybridizations: Sequence[HybridizationType],
+     ):
+         self.atomic_nums = {j: i for i, j in enumerate(atomic_nums)}
+         self.degrees = {i: i for i in degrees}
+         self.formal_charges = {j: i for i, j in enumerate(formal_charges)}
+         self.chiral_tags = {i: i for i in chiral_tags}
+         self.num_Hs = {i: i for i in num_Hs}
+         self.hybridizations = {ht: i for i, ht in enumerate(hybridizations)}
+
+         self._subfeats: list[dict] = [
+             self.atomic_nums,
+             self.degrees,
+             self.formal_charges,
+             self.chiral_tags,
+             self.num_Hs,
+             self.hybridizations,
+         ]
+         subfeat_sizes = [
+             1 + len(self.atomic_nums),
+             1 + len(self.degrees),
+             1 + len(self.formal_charges),
+             1 + len(self.chiral_tags),
+             1 + len(self.num_Hs),
+             1 + len(self.hybridizations),
+             1,
+             1,
+         ]
+         self.__size = sum(subfeat_sizes)
+
+     def __len__(self) -> int:
+         return self.__size
+
+     def __call__(self, a: Atom | None) -> np.ndarray:
+         x = np.zeros(self.__size)
+
+         if a is None:
+             return x
+
+         feats = [
+             a.GetAtomicNum(),
+             a.GetTotalDegree(),
+             a.GetFormalCharge(),
+             int(a.GetChiralTag()),
+             int(a.GetTotalNumHs()),
+             a.GetHybridization(),
+         ]
+         i = 0
+         for feat, choices in zip(feats, self._subfeats):
+             j = choices.get(feat, len(choices))
+             x[i + j] = 1
+             i += len(choices) + 1
+         x[i] = int(a.GetIsAromatic())
+         x[i + 1] = 0.01 * a.GetMass()
+
+         return x
+
+     def num_only(self, a: Atom) -> np.ndarray:
+         """featurize the atom by setting only the atomic number bit"""
+         x = np.zeros(len(self))
+
+         if a is None:
+             return x
+
+         i = self.atomic_nums.get(a.GetAtomicNum(), len(self.atomic_nums))
+         x[i] = 1
+
+         return x
+
+     @classmethod
+     def v1(cls, max_atomic_num: int = 100):
+         r"""The original implementation used in Chemprop V1 [1]_, [2]_.
+
+         Parameters
+         ----------
+         max_atomic_num : int, default=100
+             Include a bit for all atomic numbers in the interval :math:`[1, \mathtt{max\_atomic\_num}]`
+
+         References
+         ----------
+         .. [1] Yang, K.; Swanson, K.; Jin, W.; Coley, C.; Eiden, P.; Gao, H.; Guzman-Perez, A.; Hopper, T.;
+             Kelley, B.; Mathea, M.; Palmer, A. "Analyzing Learned Molecular Representations for Property Prediction."
+             J. Chem. Inf. Model. 2019, 59 (8), 3370–3388. https://doi.org/10.1021/acs.jcim.9b00237
+         .. [2] Heid, E.; Greenman, K.P.; Chung, Y.; Li, S.C.; Graff, D.E.; Vermeire, F.H.; Wu, H.; Green, W.H.; McGill,
+             C.J. "Chemprop: A machine learning package for chemical property prediction." J. Chem. Inf. Model. 2024,
+             64 (1), 9–17. https://doi.org/10.1021/acs.jcim.3c01250
+         """
+
+         return cls(
+             atomic_nums=list(range(1, max_atomic_num + 1)),
+             degrees=list(range(6)),
+             formal_charges=[-1, -2, 1, 2, 0],
+             chiral_tags=list(range(4)),
+             num_Hs=list(range(5)),
+             hybridizations=[
+                 HybridizationType.SP,
+                 HybridizationType.SP2,
+                 HybridizationType.SP3,
+                 HybridizationType.SP3D,
+                 HybridizationType.SP3D2,
+             ],
+         )
+
+     @classmethod
+     def v2(cls):
+         """An implementation that includes an atom type bit for all elements in the first four rows of the periodic table plus iodine."""
+
+         return cls(
+             atomic_nums=list(range(1, 37)) + [53],
+             degrees=list(range(6)),
+             formal_charges=[-1, -2, 1, 2, 0],
+             chiral_tags=list(range(4)),
+             num_Hs=list(range(5)),
+             hybridizations=[
+                 HybridizationType.S,
+                 HybridizationType.SP,
+                 HybridizationType.SP2,
+                 HybridizationType.SP2D,
+                 HybridizationType.SP3,
+                 HybridizationType.SP3D,
+                 HybridizationType.SP3D2,
+             ],
+         )
+
+     @classmethod
+     def organic(cls):
+         r"""A specific parameterization intended for use with organic or drug-like molecules.
+
+         This parameterization features:
+         1. an atomic number bit only for H, B, C, N, O, F, Si, P, S, Cl, Br, and I atoms
+         2. a hybridization bit only for :math:`s, sp, sp^2`, and :math:`sp^3` hybridizations.
+         """
+
+         return cls(
+             atomic_nums=[1, 5, 6, 7, 8, 9, 14, 15, 16, 17, 35, 53],
+             degrees=list(range(6)),
+             formal_charges=[-1, -2, 1, 2, 0],
+             chiral_tags=list(range(4)),
+             num_Hs=list(range(5)),
+             hybridizations=[
+                 HybridizationType.S,
+                 HybridizationType.SP,
+                 HybridizationType.SP2,
+                 HybridizationType.SP3,
+             ],
+         )
+
+
+ class RIGRAtomFeaturizer(VectorFeaturizer[Atom]):
+     """A :class:`RIGRAtomFeaturizer` uses a multi-hot encoding to featurize atoms using resonance-invariant features.
+
+     The generated atom features are ordered as follows:
+     * atomic number
+     * degree
+     * number of hydrogens
+     * mass
+     """
+
+     def __init__(
+         self,
+         atomic_nums: Sequence[int] | None = None,
+         degrees: Sequence[int] | None = None,
+         num_Hs: Sequence[int] | None = None,
+     ):
+         self.atomic_nums = {j: i for i, j in enumerate(atomic_nums or list(range(1, 37)) + [53])}
+         self.degrees = {i: i for i in (degrees or list(range(6)))}
+         self.num_Hs = {i: i for i in (num_Hs or list(range(5)))}
+
+         self._subfeats: list[dict] = [self.atomic_nums, self.degrees, self.num_Hs]
+         subfeat_sizes = [1 + len(self.atomic_nums), 1 + len(self.degrees), 1 + len(self.num_Hs), 1]
+         self.__size = sum(subfeat_sizes)
+
+     def __len__(self) -> int:
+         return self.__size
+
+     def __call__(self, a: Atom | None) -> np.ndarray:
+         x = np.zeros(self.__size)
+
+         if a is None:
+             return x
+
+         feats = [a.GetAtomicNum(), a.GetTotalDegree(), int(a.GetTotalNumHs())]
+         i = 0
+         for feat, choices in zip(feats, self._subfeats):
+             j = choices.get(feat, len(choices))
+             x[i + j] = 1
+             i += len(choices) + 1
+         x[i] = 0.01 * a.GetMass()  # scaled to about the same range as the other features
+
+         return x
+
+     def num_only(self, a: Atom) -> np.ndarray:
+         """featurize the atom by setting only the atomic number bit"""
+         x = np.zeros(len(self))
+
+         if a is None:
+             return x
+
+         i = self.atomic_nums.get(a.GetAtomicNum(), len(self.atomic_nums))
+         x[i] = 1
+
+         return x
+
+
+ class AtomFeatureMode(EnumMapping):
+     """The mode by which atoms are featurized into a :class:`MolGraph`"""
+
+     V1 = auto()
+     V2 = auto()
+     ORGANIC = auto()
+     RIGR = auto()
+
+
+ def get_multi_hot_atom_featurizer(mode: str | AtomFeatureMode) -> MultiHotAtomFeaturizer:
+     """Build the corresponding multi-hot atom featurizer."""
+     match AtomFeatureMode.get(mode):
+         case AtomFeatureMode.V1:
+             return MultiHotAtomFeaturizer.v1()
+         case AtomFeatureMode.V2:
+             return MultiHotAtomFeaturizer.v2()
+         case AtomFeatureMode.ORGANIC:
+             return MultiHotAtomFeaturizer.organic()
+         case AtomFeatureMode.RIGR:
+             return RIGRAtomFeaturizer()
+         case _:
+             raise RuntimeError("unreachable code reached!")
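A short usage sketch of the dispatch above (assuming ``EnumMapping.get`` resolves a mode from its name, as the string/enum union signature implies):

    from rdkit import Chem

    featurizer = get_multi_hot_atom_featurizer("V2")  # -> MultiHotAtomFeaturizer.v2()
    mol = Chem.MolFromSmiles("CCO")
    x = featurizer(mol.GetAtomWithIdx(0))  # multi-hot feature vector for the first carbon
    assert x.shape == (len(featurizer),)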
chemprop-updated/chemprop/featurizers/base.py ADDED
@@ -0,0 +1,30 @@
+ from abc import abstractmethod
+ from collections.abc import Sized
+ from typing import Generic, TypeVar
+
+ import numpy as np
+
+ from chemprop.data.molgraph import MolGraph
+
+ S = TypeVar("S")
+ T = TypeVar("T")
+
+
+ class Featurizer(Generic[S, T]):
+     """A :class:`Featurizer` featurizes inputs of type ``S`` into outputs of
+     type ``T``."""
+
+     @abstractmethod
+     def __call__(self, input: S, *args, **kwargs) -> T:
+         """featurize an input"""
+
+
+ class VectorFeaturizer(Featurizer[S, np.ndarray], Sized):
+     ...
+
+
+ class GraphFeaturizer(Featurizer[S, MolGraph]):
+     @property
+     @abstractmethod
+     def shape(self) -> tuple[int, int]:
+         ...
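For illustration, a minimal sketch of implementing the ``VectorFeaturizer`` interface (a hypothetical toy featurizer, not part of the package):

    import numpy as np
    from rdkit.Chem.rdchem import Atom


    class AtomicNumFeaturizer(VectorFeaturizer[Atom]):
        """A toy featurizer encoding only an atom's atomic number."""

        def __len__(self) -> int:
            # length of the produced feature vector
            return 1

        def __call__(self, a: Atom, *args, **kwargs) -> np.ndarray:
            return np.array([a.GetAtomicNum()], dtype=float)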