Upload 111 files

This view is limited to 50 files because it contains too many changes. See raw diff.
- chemprop-updated/chemprop/__init__.py +5 -0
- chemprop-updated/chemprop/__pycache__/__init__.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/args.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/constants.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/hyperopt_utils.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/hyperparameter_optimization.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/interpret.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/multitask_utils.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/nn_utils.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/rdkit.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/sklearn_predict.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/sklearn_train.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/spectra_utils.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/utils.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/cli/common.py +216 -0
- chemprop-updated/chemprop/cli/conf.py +9 -0
- chemprop-updated/chemprop/cli/convert.py +55 -0
- chemprop-updated/chemprop/cli/fingerprint.py +185 -0
- chemprop-updated/chemprop/cli/hpopt.py +540 -0
- chemprop-updated/chemprop/cli/main.py +85 -0
- chemprop-updated/chemprop/cli/predict.py +447 -0
- chemprop-updated/chemprop/cli/train.py +1343 -0
- chemprop-updated/chemprop/cli/utils/__init__.py +30 -0
- chemprop-updated/chemprop/cli/utils/actions.py +19 -0
- chemprop-updated/chemprop/cli/utils/args.py +34 -0
- chemprop-updated/chemprop/cli/utils/command.py +24 -0
- chemprop-updated/chemprop/cli/utils/parsing.py +457 -0
- chemprop-updated/chemprop/cli/utils/utils.py +31 -0
- chemprop-updated/chemprop/conf.py +6 -0
- chemprop-updated/chemprop/data/__init__.py +41 -0
- chemprop-updated/chemprop/data/__pycache__/__init__.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/data/__pycache__/data.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/data/__pycache__/scaffold.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/data/__pycache__/scaler.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/data/__pycache__/utils.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/data/collate.py +123 -0
- chemprop-updated/chemprop/data/dataloader.py +71 -0
- chemprop-updated/chemprop/data/datapoints.py +150 -0
- chemprop-updated/chemprop/data/datasets.py +475 -0
- chemprop-updated/chemprop/data/molgraph.py +17 -0
- chemprop-updated/chemprop/data/samplers.py +66 -0
- chemprop-updated/chemprop/data/splitting.py +225 -0
- chemprop-updated/chemprop/exceptions.py +12 -0
- chemprop-updated/chemprop/features/__pycache__/__init__.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/features/__pycache__/features_generators.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/features/__pycache__/featurization.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/features/__pycache__/utils.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/featurizers/__init__.py +52 -0
- chemprop-updated/chemprop/featurizers/atom.py +281 -0
- chemprop-updated/chemprop/featurizers/base.py +30 -0
chemprop-updated/chemprop/__init__.py
ADDED
@@ -0,0 +1,5 @@
from . import data, exceptions, featurizers, models, nn, schedulers, utils

__all__ = ["data", "featurizers", "models", "nn", "utils", "exceptions", "schedulers"]

__version__ = "2.1.2"
chemprop-updated/chemprop/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (743 Bytes).

chemprop-updated/chemprop/__pycache__/args.cpython-37.pyc
ADDED
Binary file (33.7 kB).

chemprop-updated/chemprop/__pycache__/constants.cpython-37.pyc
ADDED
Binary file (430 Bytes).

chemprop-updated/chemprop/__pycache__/hyperopt_utils.cpython-37.pyc
ADDED
Binary file (11.1 kB).

chemprop-updated/chemprop/__pycache__/hyperparameter_optimization.cpython-37.pyc
ADDED
Binary file (6.15 kB).

chemprop-updated/chemprop/__pycache__/interpret.cpython-37.pyc
ADDED
Binary file (14.2 kB).

chemprop-updated/chemprop/__pycache__/multitask_utils.cpython-37.pyc
ADDED
Binary file (3.12 kB).

chemprop-updated/chemprop/__pycache__/nn_utils.cpython-37.pyc
ADDED
Binary file (8.13 kB).

chemprop-updated/chemprop/__pycache__/rdkit.cpython-37.pyc
ADDED
Binary file (1.43 kB).

chemprop-updated/chemprop/__pycache__/sklearn_predict.cpython-37.pyc
ADDED
Binary file (2.82 kB).

chemprop-updated/chemprop/__pycache__/sklearn_train.cpython-37.pyc
ADDED
Binary file (11.4 kB).

chemprop-updated/chemprop/__pycache__/spectra_utils.cpython-37.pyc
ADDED
Binary file (5.1 kB).

chemprop-updated/chemprop/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (26.5 kB).
chemprop-updated/chemprop/cli/common.py
ADDED
@@ -0,0 +1,216 @@
from argparse import ArgumentError, ArgumentParser, Namespace
import logging
from pathlib import Path

from chemprop.cli.utils import LookupAction
from chemprop.cli.utils.args import uppercase
from chemprop.featurizers import AtomFeatureMode, MoleculeFeaturizerRegistry, RxnMode

logger = logging.getLogger(__name__)


def add_common_args(parser: ArgumentParser) -> ArgumentParser:
    data_args = parser.add_argument_group("Shared input data args")
    data_args.add_argument(
        "-s",
        "--smiles-columns",
        nargs="+",
        help="Column names in the input CSV containing SMILES strings (uses the 0th column by default)",
    )
    data_args.add_argument(
        "-r",
        "--reaction-columns",
        nargs="+",
        help="Column names in the input CSV containing reaction SMILES in the format ``REACTANT>AGENT>PRODUCT``, where 'AGENT' is optional",
    )
    data_args.add_argument(
        "--no-header-row",
        action="store_true",
        help="Turn off using the first row in the input CSV as column names",
    )

    dataloader_args = parser.add_argument_group("Dataloader args")
    dataloader_args.add_argument(
        "-n",
        "--num-workers",
        type=int,
        default=0,
        help="""Number of workers for parallel data loading where 0 means sequential
(Warning: setting ``num_workers`` to a value greater than 0 can cause hangs on Windows and MacOS)""",
    )
    dataloader_args.add_argument("-b", "--batch-size", type=int, default=64, help="Batch size")

    parser.add_argument(
        "--accelerator", default="auto", help="Passed directly to the lightning ``Trainer()``"
    )
    parser.add_argument(
        "--devices",
        default="auto",
        help="Passed directly to the lightning ``Trainer()`` (must be a single string of comma separated devices, e.g. '1, 2' if specifying multiple devices)",
    )

    featurization_args = parser.add_argument_group("Featurization args")
    featurization_args.add_argument(
        "--rxn-mode",
        "--reaction-mode",
        type=uppercase,
        default="REAC_DIFF",
        choices=list(RxnMode.keys()),
        help="""Choices for construction of atom and bond features for reactions (case insensitive):

- ``REAC_PROD``: concatenates the reactants feature with the products feature
- ``REAC_DIFF``: concatenates the reactants feature with the difference in features between reactants and products (Default)
- ``PROD_DIFF``: concatenates the products feature with the difference in features between reactants and products
- ``REAC_PROD_BALANCE``: concatenates the reactants feature with the products feature, balances imbalanced reactions
- ``REAC_DIFF_BALANCE``: concatenates the reactants feature with the difference in features between reactants and products, balances imbalanced reactions
- ``PROD_DIFF_BALANCE``: concatenates the products feature with the difference in features between reactants and products, balances imbalanced reactions""",
    )
    # TODO: Update documentation for multi_hot_atom_featurizer_mode
    featurization_args.add_argument(
        "--multi-hot-atom-featurizer-mode",
        type=uppercase,
        default="V2",
        choices=list(AtomFeatureMode.keys()),
        help="""Choices for multi-hot atom featurization scheme. This will affect both non-reaction and reaction featurization (case insensitive):

- ``V1``: Corresponds to the original configuration employed in Chemprop V1
- ``V2``: Tailored for a broad range of molecules, this configuration encompasses all elements in the first four rows of the periodic table, along with iodine. It is the default in Chemprop V2.
- ``ORGANIC``: Designed specifically for organic molecules in drug research and development; includes a subset of elements most common in organic chemistry, including H, B, C, N, O, F, Si, P, S, Cl, Br, and I.
- ``RIGR``: Modified V2 (default) featurizer using only the resonance-invariant atom and bond features.""",
    )
    featurization_args.add_argument(
        "--keep-h",
        action="store_true",
        help="Whether hydrogens explicitly specified in input should be kept in the mol graph",
    )
    featurization_args.add_argument(
        "--add-h", action="store_true", help="Whether hydrogens should be added to the mol graph"
    )
    data_args.add_argument(
        "--ignore-chirality",
        action="store_true",
        help="Ignore chirality information in the input SMILES",
    )
    featurization_args.add_argument(
        "--molecule-featurizers",
        "--features-generators",
        nargs="+",
        action=LookupAction(MoleculeFeaturizerRegistry),
        help="Method(s) of generating molecule features to use as extra descriptors",
    )
    # TODO: add in v2.1 to deprecate features-generators and then remove in v2.2
    # featurization_args.add_argument(
    #     "--features-generators", nargs="+", help="Renamed to `--molecule-featurizers`."
    # )
    featurization_args.add_argument(
        "--descriptors-path",
        type=Path,
        help="Path to extra descriptors to concatenate to learned representation",
    )
    # TODO: Add in v2.1
    # featurization_args.add_argument(
    #     "--phase-features-path",
    #     help="Path to features used to indicate the phase of the data in one-hot vector form. Used in spectra datatype.",
    # )
    featurization_args.add_argument(
        "--no-descriptor-scaling", action="store_true", help="Turn off extra descriptor scaling"
    )
    featurization_args.add_argument(
        "--no-atom-feature-scaling", action="store_true", help="Turn off extra atom feature scaling"
    )
    featurization_args.add_argument(
        "--no-atom-descriptor-scaling",
        action="store_true",
        help="Turn off extra atom descriptor scaling",
    )
    featurization_args.add_argument(
        "--no-bond-feature-scaling", action="store_true", help="Turn off extra bond feature scaling"
    )
    featurization_args.add_argument(
        "--atom-features-path",
        nargs="+",
        action="append",
        help="If a single path is given, it is assumed to correspond to the 0-th molecule. Alternatively, it can be a two-tuple of molecule index and path to additional atom features to supply before message passing (e.g., ``--atom-features-path 0 /path/to/features_0.npz`` indicates that the features at the given path should be supplied to the 0-th component). To supply additional features for multiple components, repeat this argument on the command line for each component's respective values (e.g., ``--atom-features-path [...] --atom-features-path [...]``).",
    )
    featurization_args.add_argument(
        "--atom-descriptors-path",
        nargs="+",
        action="append",
        help="If a single path is given, it is assumed to correspond to the 0-th molecule. Alternatively, it can be a two-tuple of molecule index and path to additional atom descriptors to supply after message passing (e.g., ``--atom-descriptors-path 0 /path/to/descriptors_0.npz`` indicates that the descriptors at the given path should be supplied to the 0-th component). To supply additional descriptors for multiple components, repeat this argument on the command line for each component's respective values (e.g., ``--atom-descriptors-path [...] --atom-descriptors-path [...]``).",
    )
    featurization_args.add_argument(
        "--bond-features-path",
        nargs="+",
        action="append",
        help="If a single path is given, it is assumed to correspond to the 0-th molecule. Alternatively, it can be a two-tuple of molecule index and path to additional bond features to supply before message passing (e.g., ``--bond-features-path 0 /path/to/features_0.npz`` indicates that the features at the given path should be supplied to the 0-th component). To supply additional features for multiple components, repeat this argument on the command line for each component's respective values (e.g., ``--bond-features-path [...] --bond-features-path [...]``).",
    )
    # TODO: Add in v2.2
    # parser.add_argument(
    #     "--constraints-path",
    #     help="Path to constraints applied to atomic/bond properties prediction.",
    # )

    return parser


def process_common_args(args: Namespace) -> Namespace:
    # TODO: add in v2.1 to deprecate features-generators and then remove in v2.2
    # if args.features_generators is not None:
    #     raise ArgumentError(
    #         argument=None,
    #         message="`--features-generators` has been renamed to `--molecule-featurizers`.",
    #     )

    for key in ["atom_features_path", "atom_descriptors_path", "bond_features_path"]:
        inds_paths = getattr(args, key)

        if not inds_paths:
            continue

        ind_path_dict = {}

        for ind_path in inds_paths:
            if len(ind_path) > 2:
                raise ArgumentError(
                    argument=None,
                    message="Too many arguments were given for atom features/descriptors or bond features. It should be either a two-tuple of molecule index and a path, or a single path (assumed to be the 0-th molecule).",
                )

            if len(ind_path) == 1:
                ind = 0
                path = ind_path[0]
            else:
                ind, path = ind_path

            if ind_path_dict.get(int(ind), None):
                raise ArgumentError(
                    argument=None,
                    message=f"Duplicate atom features/descriptors or bond features given for molecule index {ind}",
                )

            ind_path_dict[int(ind)] = Path(path)

        setattr(args, key, ind_path_dict)

    return args


def validate_common_args(args):
    pass


def find_models(model_paths: list[Path]):
    collected_model_paths = []

    for model_path in model_paths:
        if model_path.suffix in [".ckpt", ".pt"]:
            collected_model_paths.append(model_path)
        elif model_path.is_dir():
            collected_model_paths.extend(list(model_path.rglob("*.pt")))
        else:
            raise ArgumentError(
                argument=None,
                message=f"Expected a .ckpt or .pt file, or a directory. Got {model_path}",
            )

    return collected_model_paths
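A minimal sketch of how the helpers in this file compose, assuming chemprop is installed so that `chemprop.cli.common` is importable; the file name `feats_0.npz` is hypothetical:

# `--atom-features-path 0 feats_0.npz` is parsed as a (molecule index, path)
# pair; process_common_args converts the parsed pairs into a {index: Path}
# dict, matching the logic shown above.
from argparse import ArgumentParser

from chemprop.cli.common import add_common_args, process_common_args

parser = add_common_args(ArgumentParser())
args = parser.parse_args(["-s", "smiles", "--atom-features-path", "0", "feats_0.npz"])
args = process_common_args(args)
print(args.atom_features_path)  # {0: PosixPath('feats_0.npz')}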
chemprop-updated/chemprop/cli/conf.py
ADDED
@@ -0,0 +1,9 @@
from datetime import datetime
import logging
import os
from pathlib import Path

LOG_DIR = Path(os.getenv("CHEMPROP_LOG_DIR", "chemprop_logs"))
LOG_LEVELS = {0: logging.INFO, 1: logging.DEBUG, -1: logging.WARNING, -2: logging.ERROR}
NOW = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
CHEMPROP_TRAIN_DIR = Path(os.getenv("CHEMPROP_TRAIN_DIR", "chemprop_training"))
chemprop-updated/chemprop/cli/convert.py
ADDED
@@ -0,0 +1,55 @@
from argparse import ArgumentError, ArgumentParser, Namespace
import logging
from pathlib import Path
import sys

from chemprop.cli.utils import Subcommand
from chemprop.utils.v1_to_v2 import convert_model_file_v1_to_v2

logger = logging.getLogger(__name__)


class ConvertSubcommand(Subcommand):
    COMMAND = "convert"
    HELP = "Convert a v1 model checkpoint (.pt) to a v2 model checkpoint (.pt)."

    @classmethod
    def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
        parser.add_argument(
            "-i",
            "--input-path",
            required=True,
            type=Path,
            help="Path to a v1 model .pt checkpoint file",
        )
        parser.add_argument(
            "-o",
            "--output-path",
            type=Path,
            help="Path to which the converted model will be saved (``CURRENT_DIRECTORY/STEM_OF_INPUT_v2.pt`` by default)",
        )
        return parser

    @classmethod
    def func(cls, args: Namespace):
        if args.output_path is None:
            args.output_path = Path(args.input_path.stem + "_v2.pt")
        if args.output_path.suffix != ".pt":
            raise ArgumentError(
                argument=None, message=f"Output must be a `.pt` file. Got {args.output_path}"
            )

        logger.info(
            f"Converting v1 model checkpoint '{args.input_path}' to v2 model checkpoint '{args.output_path}'..."
        )
        convert_model_file_v1_to_v2(args.input_path, args.output_path)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser = ConvertSubcommand.add_args(parser)

    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)

    args = parser.parse_args()
    ConvertSubcommand.func(args)
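A minimal sketch of the default-output behavior of this subcommand, assuming a hypothetical v1 checkpoint file `model_v1.pt` exists in the working directory:

from argparse import ArgumentParser

from chemprop.cli.convert import ConvertSubcommand

parser = ConvertSubcommand.add_args(ArgumentParser())
args = parser.parse_args(["-i", "model_v1.pt"])
# With no -o given, func() defaults the output to Path("model_v1_v2.pt")
# before calling convert_model_file_v1_to_v2, per the logic above.
ConvertSubcommand.func(args)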
chemprop-updated/chemprop/cli/fingerprint.py
ADDED
@@ -0,0 +1,185 @@
from argparse import ArgumentError, ArgumentParser, Namespace
import logging
from pathlib import Path
import sys

import numpy as np
import pandas as pd
import torch

from chemprop import data
from chemprop.cli.common import add_common_args, process_common_args, validate_common_args
from chemprop.cli.predict import find_models
from chemprop.cli.utils import Subcommand, build_data_from_files, make_dataset
from chemprop.models import load_model
from chemprop.nn.metrics import LossFunctionRegistry

logger = logging.getLogger(__name__)


class FingerprintSubcommand(Subcommand):
    COMMAND = "fingerprint"
    HELP = "Use a pretrained chemprop model to calculate learned representations."

    @classmethod
    def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
        parser = add_common_args(parser)
        parser.add_argument(
            "-i",
            "--test-path",
            required=True,
            type=Path,
            help="Path to an input CSV file containing SMILES",
        )
        parser.add_argument(
            "-o",
            "--output",
            "--preds-path",
            type=Path,
            help="Specify the path where predictions will be saved. If the file extension is .npz, they will be saved as a npz file. Otherwise, the predictions will be saved as a CSV. The index of the model will be appended to the filename's stem. By default, predictions will be saved to the same location as ``--test-path`` with '_fps' appended (e.g., 'PATH/TO/TEST_PATH_fps_0.csv').",
        )
        parser.add_argument(
            "--model-paths",
            "--model-path",
            required=True,
            type=Path,
            nargs="+",
            help="Specify location of checkpoint(s) or model file(s) to use for prediction. It can be a path to either a single pretrained model checkpoint (.ckpt) or single pretrained model file (.pt), a directory that contains these files, or a list of path(s) and directory(s). If a directory, chemprop will recursively search and predict on all found (.pt) models.",
        )
        parser.add_argument(
            "--ffn-block-index",
            required=True,
            type=int,
            default=-1,
            help="The index indicates which linear layer returns the encoding in the FFN. An index of 0 denotes the post-aggregation representation through a 0-layer MLP, while an index of 1 represents the output from the first linear layer in the FFN, and so forth.",
        )

        return parser

    @classmethod
    def func(cls, args: Namespace):
        args = process_common_args(args)
        validate_common_args(args)
        args = process_fingerprint_args(args)
        main(args)


def process_fingerprint_args(args: Namespace) -> Namespace:
    if args.test_path.suffix not in [".csv"]:
        raise ArgumentError(
            argument=None, message=f"Input data must be a CSV file. Got {args.test_path}"
        )
    if args.output is None:
        args.output = args.test_path.parent / (args.test_path.stem + "_fps.csv")
    if args.output.suffix not in [".csv", ".npz"]:
        raise ArgumentError(
            argument=None, message=f"Output must be a CSV or NPZ file. Got '{args.output}'."
        )
    return args


def make_fingerprint_for_model(
    args: Namespace, model_path: Path, multicomponent: bool, output_path: Path
):
    model = load_model(model_path, multicomponent)
    model.eval()

    bounded = any(
        isinstance(model.criterion, LossFunctionRegistry[loss_function])
        for loss_function in LossFunctionRegistry.keys()
        if "bounded" in loss_function
    )

    format_kwargs = dict(
        no_header_row=args.no_header_row,
        smiles_cols=args.smiles_columns,
        rxn_cols=args.reaction_columns,
        target_cols=[],
        ignore_cols=None,
        splits_col=None,
        weight_col=None,
        bounded=bounded,
    )

    featurization_kwargs = dict(
        molecule_featurizers=args.molecule_featurizers,
        keep_h=args.keep_h,
        add_h=args.add_h,
        ignore_chirality=args.ignore_chirality,
    )

    test_data = build_data_from_files(
        args.test_path,
        **format_kwargs,
        p_descriptors=args.descriptors_path,
        p_atom_feats=args.atom_features_path,
        p_bond_feats=args.bond_features_path,
        p_atom_descs=args.atom_descriptors_path,
        **featurization_kwargs,
    )
    logger.info(f"test size: {len(test_data[0])}")
    test_dsets = [
        make_dataset(d, args.rxn_mode, args.multi_hot_atom_featurizer_mode) for d in test_data
    ]

    if multicomponent:
        test_dset = data.MulticomponentDataset(test_dsets)
    else:
        test_dset = test_dsets[0]

    test_loader = data.build_dataloader(test_dset, args.batch_size, args.num_workers, shuffle=False)

    logger.info(model)

    with torch.no_grad():
        if multicomponent:
            encodings = [
                model.encoding(batch.bmgs, batch.V_ds, batch.X_d, args.ffn_block_index)
                for batch in test_loader
            ]
        else:
            encodings = [
                model.encoding(batch.bmg, batch.V_d, batch.X_d, args.ffn_block_index)
                for batch in test_loader
            ]
        H = torch.cat(encodings, 0).numpy()

    if output_path.suffix in [".npz"]:
        np.savez(output_path, H=H)
    elif output_path.suffix == ".csv":
        fingerprint_columns = [f"fp_{i}" for i in range(H.shape[1])]
        df_fingerprints = pd.DataFrame(H, columns=fingerprint_columns)
        df_fingerprints.to_csv(output_path, index=False)
    else:
        raise ArgumentError(
            argument=None, message=f"Output must be a CSV or npz file. Got {args.output}."
        )
    logger.info(f"Fingerprints saved to '{output_path}'")


def main(args):
    match (args.smiles_columns, args.reaction_columns):
        case [None, None]:
            n_components = 1
        case [_, None]:
            n_components = len(args.smiles_columns)
        case [None, _]:
            n_components = len(args.reaction_columns)
        case _:
            n_components = len(args.smiles_columns) + len(args.reaction_columns)

    multicomponent = n_components > 1

    for i, model_path in enumerate(find_models(args.model_paths)):
        logger.info(f"Fingerprints with model {i} at '{model_path}'")
        output_path = args.output.parent / f"{args.output.stem}_{i}{args.output.suffix}"
        make_fingerprint_for_model(args, model_path, multicomponent, output_path)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser = FingerprintSubcommand.add_args(parser)

    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
    args = parser.parse_args()
    FingerprintSubcommand.func(args)
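A minimal sketch of consuming the saved representations, with hypothetical output file names; a .npz output stores the encoding matrix under the key "H" (see the np.savez call above), while a .csv output uses columns fp_0 through fp_{d-1}:

import numpy as np
import pandas as pd

H = np.load("test_fps_0.npz")["H"]  # shape: (n_molecules, encoding_dim)
df = pd.read_csv("test_fps_0.csv")  # the same values as a DataFrame
print(H.shape, df.columns[:3].tolist())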
chemprop-updated/chemprop/cli/hpopt.py
ADDED
@@ -0,0 +1,540 @@
from copy import deepcopy
import logging
from pathlib import Path
import shutil
import sys

from configargparse import ArgumentParser, Namespace
from lightning import pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping
import numpy as np
import torch

from chemprop.cli.common import add_common_args, process_common_args, validate_common_args
from chemprop.cli.train import (
    TrainSubcommand,
    add_train_args,
    build_datasets,
    build_model,
    build_splits,
    normalize_inputs,
    process_train_args,
    save_config,
    validate_train_args,
)
from chemprop.cli.utils.command import Subcommand
from chemprop.data import build_dataloader
from chemprop.nn import AggregationRegistry, MetricRegistry
from chemprop.nn.transforms import UnscaleTransform
from chemprop.nn.utils import Activation

NO_RAY = False
DEFAULT_SEARCH_SPACE = {
    "activation": None,
    "aggregation": None,
    "aggregation_norm": None,
    "batch_size": None,
    "depth": None,
    "dropout": None,
    "ffn_hidden_dim": None,
    "ffn_num_layers": None,
    "final_lr_ratio": None,
    "message_hidden_dim": None,
    "init_lr_ratio": None,
    "max_lr": None,
    "warmup_epochs": None,
}

try:
    import ray
    from ray import tune
    from ray.train import CheckpointConfig, RunConfig, ScalingConfig
    from ray.train.lightning import (
        RayDDPStrategy,
        RayLightningEnvironment,
        RayTrainReportCallback,
        prepare_trainer,
    )
    from ray.train.torch import TorchTrainer
    from ray.tune.schedulers import ASHAScheduler, FIFOScheduler

    DEFAULT_SEARCH_SPACE = {
        "activation": tune.choice(categories=list(Activation.keys())),
        "aggregation": tune.choice(categories=list(AggregationRegistry.keys())),
        "aggregation_norm": tune.quniform(lower=1, upper=200, q=1),
        "batch_size": tune.choice([16, 32, 64, 128, 256]),
        "depth": tune.qrandint(lower=2, upper=6, q=1),
        "dropout": tune.choice([0.0] * 8 + list(np.arange(0.05, 0.45, 0.05))),
        "ffn_hidden_dim": tune.qrandint(lower=300, upper=2400, q=100),
        "ffn_num_layers": tune.qrandint(lower=1, upper=3, q=1),
        "final_lr_ratio": tune.loguniform(lower=1e-2, upper=1),
        "message_hidden_dim": tune.qrandint(lower=300, upper=2400, q=100),
        "init_lr_ratio": tune.loguniform(lower=1e-2, upper=1),
        "max_lr": tune.loguniform(lower=1e-4, upper=1e-2),
        "warmup_epochs": None,
    }
except ImportError:
    NO_RAY = True

NO_HYPEROPT = False
try:
    from ray.tune.search.hyperopt import HyperOptSearch
except ImportError:
    NO_HYPEROPT = True

NO_OPTUNA = False
try:
    from ray.tune.search.optuna import OptunaSearch
except ImportError:
    NO_OPTUNA = True


logger = logging.getLogger(__name__)

SEARCH_SPACE = DEFAULT_SEARCH_SPACE

SEARCH_PARAM_KEYWORDS_MAP = {
    "basic": ["depth", "ffn_num_layers", "dropout", "ffn_hidden_dim", "message_hidden_dim"],
    "learning_rate": ["max_lr", "init_lr_ratio", "final_lr_ratio", "warmup_epochs"],
    "all": list(DEFAULT_SEARCH_SPACE.keys()),
    "init_lr": ["init_lr_ratio"],
    "final_lr": ["final_lr_ratio"],
}


class HpoptSubcommand(Subcommand):
    COMMAND = "hpopt"
    HELP = "Perform hyperparameter optimization on the given task."

    @classmethod
    def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
        parser = add_common_args(parser)
        parser = add_train_args(parser)
        return add_hpopt_args(parser)

    @classmethod
    def func(cls, args: Namespace):
        args = process_common_args(args)
        args = process_train_args(args)
        args = process_hpopt_args(args)
        validate_common_args(args)
        validate_train_args(args)
        main(args)


def add_hpopt_args(parser: ArgumentParser) -> ArgumentParser:
    hpopt_args = parser.add_argument_group("Chemprop hyperparameter optimization arguments")

    hpopt_args.add_argument(
        "--search-parameter-keywords",
        type=str,
        nargs="+",
        default=["basic"],
        help=f"""The model parameters over which to search for an optimal hyperparameter configuration. Some options are bundles of parameters or otherwise special parameter operations. Special keywords include:
- ``basic``: Default set of hyperparameters for search (depth, ffn_num_layers, dropout, message_hidden_dim, and ffn_hidden_dim)
- ``learning_rate``: Search for max_lr, init_lr_ratio, final_lr_ratio, and warmup_epochs. The searches for init_lr and final_lr values are defined as fractions of the max_lr value. The search for warmup_epochs is as a fraction of the total epochs used.
- ``all``: Include search for all 13 individual keyword options (including activation, aggregation, aggregation_norm, and batch_size, which aren't included in the other two keywords).
Individual supported parameters:
{list(DEFAULT_SEARCH_SPACE.keys())}
""",
    )

    hpopt_args.add_argument(
        "--hpopt-save-dir",
        type=Path,
        help="Directory to save the hyperparameter optimization results",
    )

    raytune_args = parser.add_argument_group("Ray Tune arguments")

    raytune_args.add_argument(
        "--raytune-num-samples",
        type=int,
        default=10,
        help="Passed directly to Ray Tune ``TuneConfig`` to control number of trials to run",
    )

    raytune_args.add_argument(
        "--raytune-search-algorithm",
        choices=["random", "hyperopt", "optuna"],
        default="hyperopt",
        help="Passed to Ray Tune ``TuneConfig`` to control search algorithm",
    )

    raytune_args.add_argument(
        "--raytune-trial-scheduler",
        choices=["FIFO", "AsyncHyperBand"],
        default="FIFO",
        help="Passed to Ray Tune ``TuneConfig`` to control trial scheduler",
    )

    raytune_args.add_argument(
        "--raytune-num-workers",
        type=int,
        default=1,
        help="Passed directly to Ray Tune ``ScalingConfig`` to control number of workers to use",
    )

    raytune_args.add_argument(
        "--raytune-use-gpu",
        action="store_true",
        help="Passed directly to Ray Tune ``ScalingConfig`` to control whether to use GPUs",
    )

    raytune_args.add_argument(
        "--raytune-num-checkpoints-to-keep",
        type=int,
        default=1,
        help="Passed directly to Ray Tune ``CheckpointConfig`` to control number of checkpoints to keep",
    )

    raytune_args.add_argument(
        "--raytune-grace-period",
        type=int,
        default=10,
        help="Passed directly to Ray Tune ``ASHAScheduler`` to control grace period",
    )

    raytune_args.add_argument(
        "--raytune-reduction-factor",
        type=int,
        default=2,
        help="Passed directly to Ray Tune ``ASHAScheduler`` to control reduction factor",
    )

    raytune_args.add_argument(
        "--raytune-temp-dir", help="Passed directly to Ray Tune init to control temporary directory"
    )

    raytune_args.add_argument(
        "--raytune-num-cpus",
        type=int,
        help="Passed directly to Ray Tune init to control number of CPUs to use",
    )

    raytune_args.add_argument(
        "--raytune-num-gpus",
        type=int,
        help="Passed directly to Ray Tune init to control number of GPUs to use",
    )

    raytune_args.add_argument(
        "--raytune-max-concurrent-trials",
        type=int,
        help="Passed directly to Ray Tune ``TuneConfig`` to control maximum concurrent trials",
    )

    hyperopt_args = parser.add_argument_group("Hyperopt arguments")

    hyperopt_args.add_argument(
        "--hyperopt-n-initial-points",
        type=int,
        help="Passed directly to ``HyperOptSearch`` to control number of initial points to sample",
    )

    hyperopt_args.add_argument(
        "--hyperopt-random-state-seed",
        type=int,
        default=None,
        help="Passed directly to ``HyperOptSearch`` to control random state seed",
    )

    return parser


def process_hpopt_args(args: Namespace) -> Namespace:
    if args.hpopt_save_dir is None:
        args.hpopt_save_dir = Path(f"chemprop_hpopt/{args.data_path.stem}")

    args.hpopt_save_dir.mkdir(exist_ok=True, parents=True)

    search_parameters = set()

    available_search_parameters = list(SEARCH_SPACE.keys()) + list(SEARCH_PARAM_KEYWORDS_MAP.keys())

    for keyword in args.search_parameter_keywords:
        if keyword not in available_search_parameters:
            raise ValueError(
                f"Search parameter keyword: {keyword} not in available options: {available_search_parameters}."
            )

        search_parameters.update(
            SEARCH_PARAM_KEYWORDS_MAP[keyword]
            if keyword in SEARCH_PARAM_KEYWORDS_MAP
            else [keyword]
        )

    args.search_parameter_keywords = list(search_parameters)

    if not args.hyperopt_n_initial_points:
        args.hyperopt_n_initial_points = args.raytune_num_samples // 2

    return args


def build_search_space(search_parameters: list[str], train_epochs: int) -> dict:
    if "warmup_epochs" in search_parameters and SEARCH_SPACE.get("warmup_epochs", None) is None:
        assert (
            train_epochs >= 6
        ), "Training epochs must be at least 6 to perform hyperparameter optimization for warmup_epochs."
        SEARCH_SPACE["warmup_epochs"] = tune.qrandint(lower=1, upper=train_epochs // 2, q=1)

    return {param: SEARCH_SPACE[param] for param in search_parameters}


def update_args_with_config(args: Namespace, config: dict) -> Namespace:
    args = deepcopy(args)

    for key, value in config.items():
        match key:
            case "final_lr_ratio":
                setattr(args, "final_lr", value * config.get("max_lr", args.max_lr))

            case "init_lr_ratio":
                setattr(args, "init_lr", value * config.get("max_lr", args.max_lr))

            case _:
                assert key in args, f"Key: {key} not found in args."
                setattr(args, key, value)

    return args


def train_model(config, args, train_dset, val_dset, logger, output_transform, input_transforms):
    args = update_args_with_config(args, config)

    train_loader = build_dataloader(
        train_dset, args.batch_size, args.num_workers, seed=args.data_seed
    )
    val_loader = build_dataloader(val_dset, args.batch_size, args.num_workers, shuffle=False)

    seed = args.pytorch_seed if args.pytorch_seed is not None else torch.seed()

    torch.manual_seed(seed)

    model = build_model(args, train_loader.dataset, output_transform, input_transforms)
    logger.info(model)

    if args.tracking_metric == "val_loss":
        T_tracking_metric = model.criterion.__class__
    else:
        T_tracking_metric = MetricRegistry[args.tracking_metric]
        args.tracking_metric = "val/" + args.tracking_metric

    monitor_mode = "max" if T_tracking_metric.higher_is_better else "min"
    logger.debug(f"Evaluation metric: '{T_tracking_metric.alias}', mode: '{monitor_mode}'")

    patience = args.patience if args.patience is not None else args.epochs
    early_stopping = EarlyStopping(args.tracking_metric, patience=patience, mode=monitor_mode)

    trainer = pl.Trainer(
        accelerator=args.accelerator,
        devices=args.devices,
        max_epochs=args.epochs,
        gradient_clip_val=args.grad_clip,
        strategy=RayDDPStrategy(),
        callbacks=[RayTrainReportCallback(), early_stopping],
        plugins=[RayLightningEnvironment()],
        deterministic=args.pytorch_seed is not None,
    )
    trainer = prepare_trainer(trainer)
    trainer.fit(model, train_loader, val_loader)


def tune_model(
    args, train_dset, val_dset, logger, monitor_mode, output_transform, input_transforms
):
    match args.raytune_trial_scheduler:
        case "FIFO":
            scheduler = FIFOScheduler()
        case "AsyncHyperBand":
            scheduler = ASHAScheduler(
                max_t=args.epochs,
                grace_period=min(args.raytune_grace_period, args.epochs),
                reduction_factor=args.raytune_reduction_factor,
            )
        case _:
            raise ValueError(f"Invalid trial scheduler! got: {args.raytune_trial_scheduler}.")

    resources_per_worker = {}
    if args.raytune_num_cpus and args.raytune_max_concurrent_trials:
        resources_per_worker["CPU"] = args.raytune_num_cpus / args.raytune_max_concurrent_trials
    if args.raytune_num_gpus and args.raytune_max_concurrent_trials:
        resources_per_worker["GPU"] = args.raytune_num_gpus / args.raytune_max_concurrent_trials
    if not resources_per_worker:
        resources_per_worker = None

    if args.raytune_num_gpus:
        use_gpu = True
    else:
        use_gpu = args.raytune_use_gpu

    scaling_config = ScalingConfig(
        num_workers=args.raytune_num_workers,
        use_gpu=use_gpu,
        resources_per_worker=resources_per_worker,
        trainer_resources={"CPU": 0},
    )

    checkpoint_config = CheckpointConfig(
        num_to_keep=args.raytune_num_checkpoints_to_keep,
        checkpoint_score_attribute=args.tracking_metric,
        checkpoint_score_order=monitor_mode,
    )

    run_config = RunConfig(
        checkpoint_config=checkpoint_config,
        storage_path=args.hpopt_save_dir.absolute() / "ray_results",
    )

    ray_trainer = TorchTrainer(
        lambda config: train_model(
            config, args, train_dset, val_dset, logger, output_transform, input_transforms
        ),
        scaling_config=scaling_config,
        run_config=run_config,
    )

    match args.raytune_search_algorithm:
        case "random":
            search_alg = None
        case "hyperopt":
            if NO_HYPEROPT:
                raise ImportError(
                    "HyperOptSearch requires hyperopt to be installed. Use 'pip install -U hyperopt' to install, or use 'pip install -e .[hpopt]' in the chemprop folder if you installed from source to install all hpopt-relevant packages."
                )

            search_alg = HyperOptSearch(
                n_initial_points=args.hyperopt_n_initial_points,
                random_state_seed=args.hyperopt_random_state_seed,
            )
        case "optuna":
            if NO_OPTUNA:
                raise ImportError(
                    "OptunaSearch requires optuna to be installed. Use 'pip install -U optuna' to install, or use 'pip install -e .[hpopt]' in the chemprop folder if you installed from source to install all hpopt-relevant packages."
                )

            search_alg = OptunaSearch()

    tune_config = tune.TuneConfig(
        metric=args.tracking_metric,
        mode=monitor_mode,
        num_samples=args.raytune_num_samples,
        scheduler=scheduler,
        search_alg=search_alg,
        trial_dirname_creator=lambda trial: str(trial.trial_id),
    )

    tuner = tune.Tuner(
        ray_trainer,
        param_space={
            "train_loop_config": build_search_space(args.search_parameter_keywords, args.epochs)
        },
        tune_config=tune_config,
    )

    return tuner.fit()


def main(args: Namespace):
    if NO_RAY:
        raise ImportError(
            "Ray Tune requires ray to be installed. If you installed Chemprop from PyPI, run 'pip install -U ray[tune]' to install ray. If you installed from source, use 'pip install -e .[hpopt]' in the Chemprop folder to install all hpopt-relevant packages."
        )

    if not ray.is_initialized():
        try:
            ray.init(
                _temp_dir=args.raytune_temp_dir,
                num_cpus=args.raytune_num_cpus,
                num_gpus=args.raytune_num_gpus,
            )
        except OSError as e:
            if "AF_UNIX path length cannot exceed 107 bytes" in str(e):
                raise OSError(
                    f"Ray Tune fails due to: {e}. This can sometimes be solved by providing a temporary directory, num_cpus, and num_gpus to Ray Tune via the CLI: --raytune-temp-dir <absolute_path> --raytune-num-cpus <int> --raytune-num-gpus <int>."
                )
            else:
                raise e
    else:
        logger.info("Ray is already initialized.")

    format_kwargs = dict(
        no_header_row=args.no_header_row,
        smiles_cols=args.smiles_columns,
        rxn_cols=args.reaction_columns,
        target_cols=args.target_columns,
        ignore_cols=args.ignore_columns,
        splits_col=args.splits_column,
        weight_col=args.weight_column,
        bounded=args.loss_function is not None and "bounded" in args.loss_function,
    )

    featurization_kwargs = dict(
        molecule_featurizers=args.molecule_featurizers,
        keep_h=args.keep_h,
        add_h=args.add_h,
        ignore_chirality=args.ignore_chirality,
    )

    train_data, val_data, test_data = build_splits(args, format_kwargs, featurization_kwargs)
    train_dset, val_dset, test_dset = build_datasets(args, train_data[0], val_data[0], test_data[0])

    input_transforms = normalize_inputs(train_dset, val_dset, args)

    if "regression" in args.task_type:
        output_scaler = train_dset.normalize_targets()
        val_dset.normalize_targets(output_scaler)
        logger.info(f"Train data: mean = {output_scaler.mean_} | std = {output_scaler.scale_}")
        output_transform = UnscaleTransform.from_standard_scaler(output_scaler)
    else:
        output_transform = None

    train_loader = build_dataloader(
        train_dset, args.batch_size, args.num_workers, seed=args.data_seed
    )

    model = build_model(args, train_loader.dataset, output_transform, input_transforms)
    monitor_mode = "max" if model.metrics[0].higher_is_better else "min"

    results = tune_model(
        args, train_dset, val_dset, logger, monitor_mode, output_transform, input_transforms
    )

    best_result = results.get_best_result()
    best_config = best_result.config["train_loop_config"]
    best_checkpoint_path = Path(best_result.checkpoint.path) / "checkpoint.ckpt"

    best_config_save_path = args.hpopt_save_dir / "best_config.toml"
    best_checkpoint_save_path = args.hpopt_save_dir / "best_checkpoint.ckpt"
    all_progress_save_path = args.hpopt_save_dir / "all_progress.csv"

    logger.info(f"Best hyperparameters saved to: '{best_config_save_path}'")

    args = update_args_with_config(args, best_config)

    args = TrainSubcommand.parser.parse_known_args(namespace=args)[0]
    save_config(TrainSubcommand.parser, args, best_config_save_path)

    logger.info(
        f"Best hyperparameter configuration checkpoint saved to '{best_checkpoint_save_path}'"
    )

    shutil.copyfile(best_checkpoint_path, best_checkpoint_save_path)

    logger.info(f"Hyperparameter optimization results saved to '{all_progress_save_path}'")

    result_df = results.get_dataframe()

    result_df.to_csv(all_progress_save_path, index=False)

    ray.shutdown()


if __name__ == "__main__":
    parser = ArgumentParser()
    parser = HpoptSubcommand.add_args(parser)

    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
    args = parser.parse_args()
    HpoptSubcommand.func(args)
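A minimal sketch of how update_args_with_config resolves the sampled learning-rate ratios (the numbers are hypothetical): ratios are multiplied by the sampled (or existing) max_lr before being written back as init_lr and final_lr, per the match statement above.

from argparse import Namespace

from chemprop.cli.hpopt import update_args_with_config

args = Namespace(max_lr=1e-3, init_lr=1e-4, final_lr=1e-4)
config = {"max_lr": 1e-3, "init_lr_ratio": 0.1, "final_lr_ratio": 0.01}
args = update_args_with_config(args, config)
print(args.init_lr, args.final_lr)  # 0.0001 1e-05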
chemprop-updated/chemprop/cli/main.py
ADDED
@@ -0,0 +1,85 @@
import logging
from pathlib import Path
import sys

from configargparse import ArgumentParser

from chemprop.cli.conf import LOG_DIR, LOG_LEVELS, NOW
from chemprop.cli.convert import ConvertSubcommand
from chemprop.cli.fingerprint import FingerprintSubcommand
from chemprop.cli.hpopt import HpoptSubcommand
from chemprop.cli.predict import PredictSubcommand
from chemprop.cli.train import TrainSubcommand
from chemprop.cli.utils import pop_attr

logger = logging.getLogger(__name__)

SUBCOMMANDS = [
    TrainSubcommand,
    PredictSubcommand,
    ConvertSubcommand,
    FingerprintSubcommand,
    HpoptSubcommand,
]


def construct_parser():
    parser = ArgumentParser()
    subparsers = parser.add_subparsers(title="mode", dest="mode", required=True)

    parent = ArgumentParser(add_help=False)
    parent.add_argument(
        "--logfile",
        "--log",
        nargs="?",
        const="default",
        help=f"Path to which the log file should be written (specifying just the flag alone will automatically log to a file ``{LOG_DIR}/MODE/TIMESTAMP.log``, where 'MODE' is the CLI mode chosen, e.g., ``{LOG_DIR}/MODE/{NOW}.log``)",
    )
    parent.add_argument("-v", action="store_true", help="Increase verbosity level to DEBUG")
    parent.add_argument(
        "-q",
        action="count",
        default=0,
        help="Decrease verbosity level to WARNING, or ERROR if specified twice",
    )

    parents = [parent]
    for subcommand in SUBCOMMANDS:
        subcommand.add(subparsers, parents)

    return parser


def main():
    parser = construct_parser()
    args = parser.parse_args()
    logfile, v_flag, q_count, mode, func = (
        pop_attr(args, attr) for attr in ["logfile", "v", "q", "mode", "func"]
    )

    if v_flag and q_count:
        parser.error("The -v and -q options cannot be used together.")

    match logfile:
        case None:
            handler = logging.StreamHandler(sys.stderr)
        case "default":
            (LOG_DIR / mode).mkdir(parents=True, exist_ok=True)
            handler = logging.FileHandler(str(LOG_DIR / mode / f"{NOW}.log"))
        case _:
            Path(logfile).parent.mkdir(parents=True, exist_ok=True)
            handler = logging.FileHandler(logfile)

    verbosity = q_count * -1 if q_count else (1 if v_flag else 0)
    logging_level = LOG_LEVELS.get(verbosity, logging.ERROR)
    logging.basicConfig(
        handlers=[handler],
        format="%(asctime)s - %(levelname)s:%(name)s - %(message)s",
        level=logging_level,
        datefmt="%Y-%m-%dT%H:%M:%S",
        force=True,
    )

    logger.info(f"Running in mode '{mode}' with args: {vars(args)}")

    func(args)
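A minimal, self-contained sketch of the verbosity mapping used in main() above: no flag selects INFO, -v selects DEBUG, -q selects WARNING, -qq selects ERROR, and anything lower falls back to ERROR via LOG_LEVELS.get.

import logging

LOG_LEVELS = {0: logging.INFO, 1: logging.DEBUG, -1: logging.WARNING, -2: logging.ERROR}

for v_flag, q_count in [(False, 0), (True, 0), (False, 1), (False, 2)]:
    verbosity = q_count * -1 if q_count else (1 if v_flag else 0)
    print(verbosity, logging.getLevelName(LOG_LEVELS.get(verbosity, logging.ERROR)))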
chemprop-updated/chemprop/cli/predict.py
ADDED
@@ -0,0 +1,447 @@
from argparse import ArgumentError, ArgumentParser, Namespace
import logging
from pathlib import Path
import sys
from typing import Iterator

from lightning import pytorch as pl
import numpy as np
import pandas as pd
import torch

from chemprop import data
from chemprop.cli.common import (
    add_common_args,
    find_models,
    process_common_args,
    validate_common_args,
)
from chemprop.cli.utils import LookupAction, Subcommand, build_data_from_files, make_dataset
from chemprop.models.utils import load_model, load_output_columns
from chemprop.nn.metrics import LossFunctionRegistry
from chemprop.nn.predictors import EvidentialFFN, MulticlassClassificationFFN, MveFFN
from chemprop.uncertainty import (
    MVEWeightingCalibrator,
    NoUncertaintyEstimator,
    RegressionCalibrator,
    RegressionEvaluator,
    UncertaintyCalibratorRegistry,
    UncertaintyEstimatorRegistry,
    UncertaintyEvaluatorRegistry,
)
from chemprop.utils import Factory

logger = logging.getLogger(__name__)


class PredictSubcommand(Subcommand):
    COMMAND = "predict"
    HELP = "use a pretrained chemprop model for prediction"

    @classmethod
    def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
        parser = add_common_args(parser)
        return add_predict_args(parser)

    @classmethod
    def func(cls, args: Namespace):
        args = process_common_args(args)
        validate_common_args(args)
        args = process_predict_args(args)
        main(args)


def add_predict_args(parser: ArgumentParser) -> ArgumentParser:
    parser.add_argument(
        "-i",
        "--test-path",
        required=True,
        type=Path,
        help="Path to an input CSV file containing SMILES",
    )
    parser.add_argument(
        "-o",
        "--output",
        "--preds-path",
        type=Path,
        help="Specify path to which predictions will be saved. If the file extension is .pkl, it will be saved as a pickle file. Otherwise, chemprop will save predictions as a CSV. If multiple models are used to make predictions, the average predictions will be saved in the file, and another file ending in '_individual' with the same file extension will save the predictions for each individual model, with the column names being the target names appended with the model index (e.g., '_model_<index>').",
    )
    parser.add_argument(
        "--drop-extra-columns",
        action="store_true",
        help="Whether to drop all columns from the test data file besides the SMILES columns and the new prediction columns",
    )
    parser.add_argument(
        "--model-paths",
        "--model-path",
        required=True,
        type=Path,
        nargs="+",
        help="Location of checkpoint(s) or model file(s) to use for prediction. It can be a path to either a single pretrained model checkpoint (.ckpt) or single pretrained model file (.pt), a directory that contains these files, or a list of path(s) and directory(s). If a directory, will recursively search and predict on all found (.pt) models.",
    )

    unc_args = parser.add_argument_group("Uncertainty and calibration args")
    unc_args.add_argument(
        "--cal-path", type=Path, help="Path to data file to be used for uncertainty calibration."
    )
    unc_args.add_argument(
        "--uncertainty-method",
        default="none",
        action=LookupAction(UncertaintyEstimatorRegistry),
        help="The method of calculating uncertainty.",
    )
    unc_args.add_argument(
        "--calibration-method",
        action=LookupAction(UncertaintyCalibratorRegistry),
        help="The method used for calibrating the uncertainty calculated with the uncertainty method.",
    )
    unc_args.add_argument(
        "--evaluation-methods",
        "--evaluation-method",
        nargs="+",
        action=LookupAction(UncertaintyEvaluatorRegistry),
        help="The methods used for evaluating the uncertainty performance if the test data provided includes targets. Available methods are [nll, miscalibration_area, ence, spearman] or any available classification or multiclass metric.",
    )
    # unc_args.add_argument(
    #     "--evaluation-scores-path", help="Location to save the results of uncertainty evaluations."
    # )
    unc_args.add_argument(
        "--uncertainty-dropout-p",
        type=float,
        default=0.1,
        help="The probability to use for Monte Carlo dropout uncertainty estimation.",
    )
    unc_args.add_argument(
        "--dropout-sampling-size",
        type=int,
        default=10,
        help="The number of samples to use for Monte Carlo dropout uncertainty estimation. Distinct from the dropout used during training.",
    )
    unc_args.add_argument(
        "--calibration-interval-percentile",
        type=float,
        default=95,
        help="Sets the percentile used in the calibration methods. Must be in the range (1, 100).",
    )
    unc_args.add_argument(
        "--conformal-alpha",
        type=float,
        default=0.1,
        help="Target error rate for conformal prediction. Must be in the range (0, 1).",
    )
    # TODO: Decide if we want to implement this in v2.1.x
    # unc_args.add_argument(
    #     "--regression-calibrator-metric",
    #     choices=["stdev", "interval"],
    #     help="Regression calibrators can output either a stdev or an interval.",
    # )
    unc_args.add_argument(
        "--cal-descriptors-path",
        nargs="+",
        action="append",
        help="Path to extra descriptors to concatenate to learned representation in calibration dataset.",
    )
    # TODO: Add in v2.1.x
    # unc_args.add_argument(
    #     "--calibration-phase-features-path",
    #     help=" ",
    # )
    unc_args.add_argument(
        "--cal-atom-features-path",
        nargs="+",
        action="append",
        help="Path to the extra atom features in calibration dataset.",
    )
    unc_args.add_argument(
        "--cal-atom-descriptors-path",
        nargs="+",
        action="append",
        help="Path to the extra atom descriptors in calibration dataset.",
    )
    unc_args.add_argument(
        "--cal-bond-features-path",
        nargs="+",
        action="append",
        help="Path to the extra bond features in calibration dataset.",
    )

    return parser


def process_predict_args(args: Namespace) -> Namespace:
    if args.test_path.suffix not in [".csv"]:
        raise ArgumentError(
            argument=None, message=f"Input data must be a CSV file. Got {args.test_path}"
        )
    if args.output is None:
        args.output = args.test_path.parent / (args.test_path.stem + "_preds.csv")
    if args.output.suffix not in [".csv", ".pkl"]:
        raise ArgumentError(
            argument=None, message=f"Output must be a CSV or Pickle file. Got {args.output}"
        )
    return args

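# For example (hypothetical paths): `chemprop predict -i data/test.csv --model-paths model.pt`
# with no -o flag falls through the `args.output is None` branch above and writes its
# averaged predictions to data/test_preds.csv next to the input file.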

def prepare_data_loader(
    args: Namespace, multicomponent: bool, is_calibration: bool, format_kwargs: dict
):
    data_path = args.cal_path if is_calibration else args.test_path
    descriptors_path = args.cal_descriptors_path if is_calibration else args.descriptors_path
    atom_feats_path = args.cal_atom_features_path if is_calibration else args.atom_features_path
    bond_feats_path = args.cal_bond_features_path if is_calibration else args.bond_features_path
    atom_descs_path = (
        args.cal_atom_descriptors_path if is_calibration else args.atom_descriptors_path
    )

    featurization_kwargs = dict(
        molecule_featurizers=args.molecule_featurizers,
        keep_h=args.keep_h,
        add_h=args.add_h,
        ignore_chirality=args.ignore_chirality,
    )

    datas = build_data_from_files(
        data_path,
        **format_kwargs,
        p_descriptors=descriptors_path,
        p_atom_feats=atom_feats_path,
        p_bond_feats=bond_feats_path,
        p_atom_descs=atom_descs_path,
        **featurization_kwargs,
    )

    dsets = [make_dataset(d, args.rxn_mode, args.multi_hot_atom_featurizer_mode) for d in datas]
    dset = data.MulticomponentDataset(dsets) if multicomponent else dsets[0]

    return data.build_dataloader(dset, args.batch_size, args.num_workers, shuffle=False)


def make_prediction_for_models(
    args: Namespace, model_paths: Iterator[Path], multicomponent: bool, output_path: Path
):
    model = load_model(model_paths[0], multicomponent)
    output_columns = load_output_columns(model_paths[0])
    bounded = any(
        isinstance(model.criterion, LossFunctionRegistry[loss_function])
        for loss_function in LossFunctionRegistry.keys()
        if "bounded" in loss_function
    )
    format_kwargs = dict(
        no_header_row=args.no_header_row,
        smiles_cols=args.smiles_columns,
        rxn_cols=args.reaction_columns,
        ignore_cols=None,
        splits_col=None,
        weight_col=None,
        bounded=bounded,
    )
    format_kwargs["target_cols"] = output_columns if args.evaluation_methods is not None else []
    test_loader = prepare_data_loader(args, multicomponent, False, format_kwargs)
    logger.info(f"test size: {len(test_loader.dataset)}")
    if args.cal_path is not None:
        format_kwargs["target_cols"] = output_columns
        cal_loader = prepare_data_loader(args, multicomponent, True, format_kwargs)
        logger.info(f"calibration size: {len(cal_loader.dataset)}")

    uncertainty_estimator = Factory.build(
        UncertaintyEstimatorRegistry[args.uncertainty_method],
        ensemble_size=args.dropout_sampling_size,
        dropout=args.uncertainty_dropout_p,
    )

    models = [load_model(model_path, multicomponent) for model_path in model_paths]
    trainer = pl.Trainer(
        logger=False, enable_progress_bar=True, accelerator=args.accelerator, devices=args.devices
    )
    test_individual_preds, test_individual_uncs = uncertainty_estimator(
        test_loader, models, trainer
    )
    test_preds = torch.mean(test_individual_preds, dim=0)
    if not isinstance(uncertainty_estimator, NoUncertaintyEstimator):
        test_uncs = torch.mean(test_individual_uncs, dim=0)
    else:
        test_uncs = None

    if args.calibration_method is not None:
        uncertainty_calibrator = Factory.build(
            UncertaintyCalibratorRegistry[args.calibration_method],
            p=args.calibration_interval_percentile / 100,
            alpha=args.conformal_alpha,
        )
        cal_targets = cal_loader.dataset.Y
        cal_mask = torch.from_numpy(np.isfinite(cal_targets))
        cal_targets = np.nan_to_num(cal_targets, nan=0.0)
        cal_targets = torch.from_numpy(cal_targets)
        cal_individual_preds, cal_individual_uncs = uncertainty_estimator(
            cal_loader, models, trainer
        )
        cal_preds = torch.mean(cal_individual_preds, dim=0)
        cal_uncs = torch.mean(cal_individual_uncs, dim=0)
        if isinstance(uncertainty_calibrator, MVEWeightingCalibrator):
            uncertainty_calibrator.fit(cal_preds, cal_individual_uncs, cal_targets, cal_mask)
            test_uncs = uncertainty_calibrator.apply(test_individual_uncs)
        else:
            if isinstance(uncertainty_calibrator, RegressionCalibrator):
                uncertainty_calibrator.fit(cal_preds, cal_uncs, cal_targets, cal_mask)
            else:
                uncertainty_calibrator.fit(cal_uncs, cal_targets, cal_mask)
            test_uncs = uncertainty_calibrator.apply(test_uncs)
            for i in range(test_individual_uncs.shape[0]):
                test_individual_uncs[i] = uncertainty_calibrator.apply(test_individual_uncs[i])

    if args.evaluation_methods is not None:
        uncertainty_evaluators = [
            Factory.build(UncertaintyEvaluatorRegistry[method])
            for method in args.evaluation_methods
        ]
        logger.info("Uncertainty evaluation metric:")
        for evaluator in uncertainty_evaluators:
            test_targets = test_loader.dataset.Y
            test_mask = torch.from_numpy(np.isfinite(test_targets))
            test_targets = np.nan_to_num(test_targets, nan=0.0)
            test_targets = torch.from_numpy(test_targets)
            if isinstance(evaluator, RegressionEvaluator):
                metric_value = evaluator.evaluate(test_preds, test_uncs, test_targets, test_mask)
            else:
                metric_value = evaluator.evaluate(test_uncs, test_targets, test_mask)
            logger.info(f"{evaluator.alias}: {metric_value.tolist()}")

    if args.uncertainty_method == "none" and (
        isinstance(model.predictor, MveFFN) or isinstance(model.predictor, EvidentialFFN)
    ):
        test_preds = test_preds[..., 0]
        test_individual_preds = test_individual_preds[..., 0]

    if output_columns is None:
        output_columns = [
            f"pred_{i}" for i in range(test_preds.shape[1])
        ]  # TODO: need to improve this for cases like multi-task MVE and multi-task multiclass

    save_predictions(args, model, output_columns, test_preds, test_uncs, output_path)

    if len(model_paths) > 1:
        save_individual_predictions(
            args,
            model,
            model_paths,
            output_columns,
            test_individual_preds,
            test_individual_uncs,
            output_path,
        )


def save_predictions(args, model, output_columns, test_preds, test_uncs, output_path):
    unc_columns = [f"{col}_unc" for col in output_columns]

    if isinstance(model.predictor, MulticlassClassificationFFN):
        output_columns = output_columns + [f"{col}_prob" for col in output_columns]
        predicted_class_labels = test_preds.argmax(axis=-1)
        formatted_probability_strings = np.apply_along_axis(
            lambda x: ",".join(map(str, x)), 2, test_preds.numpy()
        )
        test_preds = np.concatenate(
            (predicted_class_labels, formatted_probability_strings), axis=-1
        )

    df_test = pd.read_csv(
        args.test_path, header=None if args.no_header_row else "infer", index_col=False
    )
    df_test[output_columns] = test_preds

    if args.uncertainty_method not in ["none", "classification"]:
        df_test[unc_columns] = np.round(test_uncs, 6)

    if output_path.suffix == ".pkl":
        df_test = df_test.reset_index(drop=True)
        df_test.to_pickle(output_path)
    else:
        df_test.to_csv(output_path, index=False)
    logger.info(f"Predictions saved to '{output_path}'")


def save_individual_predictions(
    args,
    model,
    model_paths,
    output_columns,
    test_individual_preds,
    test_individual_uncs,
    output_path,
):
    unc_columns = [
        f"{col}_unc_model_{i}" for i in range(len(model_paths)) for col in output_columns
    ]

    if isinstance(model.predictor, MulticlassClassificationFFN):
        output_columns = [
            item
            for i in range(len(model_paths))
            for col in output_columns
            for item in (f"{col}_model_{i}", f"{col}_prob_model_{i}")
        ]

        predicted_class_labels = test_individual_preds.argmax(axis=-1)
        formatted_probability_strings = np.apply_along_axis(
            lambda x: ",".join(map(str, x)), 3, test_individual_preds.numpy()
        )
        test_individual_preds = np.concatenate(
            (predicted_class_labels, formatted_probability_strings), axis=-1
        )
    else:
        output_columns = [
            f"{col}_model_{i}" for i in range(len(model_paths)) for col in output_columns
        ]

    m, n, t = test_individual_preds.shape
    test_individual_preds = np.transpose(test_individual_preds, (1, 0, 2)).reshape(n, m * t)
    df_test = pd.read_csv(
        args.test_path, header=None if args.no_header_row else "infer", index_col=False
    )
    df_test[output_columns] = test_individual_preds

    if args.uncertainty_method not in ["none", "classification", "ensemble"]:
        m, n, t = test_individual_uncs.shape
        test_individual_uncs = np.transpose(test_individual_uncs, (1, 0, 2)).reshape(n, m * t)
        df_test[unc_columns] = np.round(test_individual_uncs, 6)

    output_path = output_path.parent / Path(
        str(args.output.stem) + "_individual" + str(output_path.suffix)
    )
    if output_path.suffix == ".pkl":
        df_test = df_test.reset_index(drop=True)
        df_test.to_pickle(output_path)
    else:
        df_test.to_csv(output_path, index=False)
    logger.info(f"Individual predictions saved to '{output_path}'")
    for i, model_path in enumerate(model_paths):
        logger.info(
            f"Results from model path {model_path} are saved under the column name ending with 'model_{i}'"
        )


def main(args):
    match (args.smiles_columns, args.reaction_columns):
        case [None, None]:
            n_components = 1
        case [_, None]:
            n_components = len(args.smiles_columns)
        case [None, _]:
            n_components = len(args.reaction_columns)
        case _:
            n_components = len(args.smiles_columns) + len(args.reaction_columns)

    multicomponent = n_components > 1

    model_paths = find_models(args.model_paths)

    make_prediction_for_models(args, model_paths, multicomponent, output_path=args.output)


if __name__ == "__main__":
    parser = ArgumentParser()
    parser = PredictSubcommand.add_args(parser)

    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
    args = parser.parse_args()
    args = PredictSubcommand.func(args)
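For reference, the subcommand can be driven from Python the same way the `__main__` block above does it; a minimal sketch with hypothetical input and checkpoint paths (all flags registered by `add_common_args` keep their defaults):

    from argparse import ArgumentParser

    # Build the predict parser, parse an explicit argv list, and run prediction.
    parser = PredictSubcommand.add_args(ArgumentParser())
    args = parser.parse_args(
        ["-i", "test_smiles.csv", "--model-paths", "model_0/best.pt", "-o", "preds.csv"]
    )
    PredictSubcommand.func(args)  # writes averaged predictions to preds.csv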
chemprop-updated/chemprop/cli/train.py
ADDED
@@ -0,0 +1,1343 @@
from copy import deepcopy
from io import StringIO
import json
import logging
from pathlib import Path
import sys
from tempfile import TemporaryDirectory

from configargparse import ArgumentError, ArgumentParser, Namespace
from lightning import pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.strategies import DDPStrategy
import numpy as np
import pandas as pd
from rich.console import Console
from rich.table import Column, Table
import torch
import torch.nn as nn

from chemprop.cli.common import (
    add_common_args,
    find_models,
    process_common_args,
    validate_common_args,
)
from chemprop.cli.conf import CHEMPROP_TRAIN_DIR, NOW
from chemprop.cli.utils import (
    LookupAction,
    Subcommand,
    build_data_from_files,
    get_column_names,
    make_dataset,
    parse_indices,
)
from chemprop.cli.utils.args import uppercase
from chemprop.data import (
    MoleculeDataset,
    MolGraphDataset,
    MulticomponentDataset,
    ReactionDatapoint,
    SplitType,
    build_dataloader,
    make_split_indices,
    split_data_by_indices,
)
from chemprop.data.datasets import _MolGraphDatasetMixin
from chemprop.models import MPNN, MulticomponentMPNN, save_model
from chemprop.nn import AggregationRegistry, LossFunctionRegistry, MetricRegistry, PredictorRegistry
from chemprop.nn.message_passing import (
    AtomMessagePassing,
    BondMessagePassing,
    MulticomponentMessagePassing,
)
from chemprop.nn.transforms import GraphTransform, ScaleTransform, UnscaleTransform
from chemprop.nn.utils import Activation
from chemprop.utils import Factory

logger = logging.getLogger(__name__)


_CV_REMOVAL_ERROR = (
    "The -k/--num-folds argument was removed in v2.1.0 - use --num-replicates instead."
)


class TrainSubcommand(Subcommand):
    COMMAND = "train"
    HELP = "Train a chemprop model."
    parser = None

    @classmethod
    def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
        parser = add_common_args(parser)
        parser = add_train_args(parser)
        cls.parser = parser
        return parser

    @classmethod
    def func(cls, args: Namespace):
        args = process_common_args(args)
        validate_common_args(args)
        args = process_train_args(args)
        validate_train_args(args)

        args.output_dir.mkdir(exist_ok=True, parents=True)
        config_path = args.output_dir / "config.toml"
        save_config(cls.parser, args, config_path)
        main(args)


def add_train_args(parser: ArgumentParser) -> ArgumentParser:
    parser.add_argument(
        "--config-path",
        type=Path,
        is_config_file=True,
        help="Path to a configuration file (command line arguments override values in the configuration file)",
    )
    parser.add_argument(
        "-i",
        "--data-path",
        type=Path,
        help="Path to an input CSV file containing SMILES and the associated target values",
    )
    parser.add_argument(
        "-o",
        "--output-dir",
        "--save-dir",
        type=Path,
        help="Directory where training outputs will be saved (defaults to ``CURRENT_DIRECTORY/chemprop_training/STEM_OF_INPUT/TIME_STAMP``)",
    )
    parser.add_argument(
        "--remove-checkpoints",
        action="store_true",
        help="Remove intermediate checkpoint files after training is complete.",
    )

    # TODO: Add in v2.1; see if we can tell lightning how often to log training loss
    # parser.add_argument(
    #     "--log-frequency",
    #     type=int,
    #     default=10,
    #     help="The number of batches between each logging of the training loss.",
    # )

    transfer_args = parser.add_argument_group("transfer learning args")
    transfer_args.add_argument(
        "--checkpoint",
        type=Path,
        nargs="+",
        help="Path to checkpoint(s) or model file(s) for loading and overwriting weights. Accepts a single pre-trained model checkpoint (.ckpt), a single model file (.pt), a directory containing such files, or a list of paths and directories. If a directory is provided, it will recursively search for and use all (.pt) files found for prediction.",
    )
    transfer_args.add_argument(
        "--freeze-encoder",
        action="store_true",
        help="Freeze the message passing layer from the checkpoint model (specified by ``--checkpoint``).",
    )
    transfer_args.add_argument(
        "--model-frzn",
        help="Path to model checkpoint file to be loaded for overwriting and freezing weights. By default, all MPNN weights are frozen with this option.",
    )
    transfer_args.add_argument(
        "--frzn-ffn-layers",
        type=int,
        default=0,
        help="Freeze the first ``n`` layers of the FFN from the checkpoint model (specified by ``--checkpoint``). The message passing layer should also be frozen with ``--freeze-encoder``.",
    )
    # transfer_args.add_argument(
    #     "--freeze-first-only",
    #     action="store_true",
    #     help="Determines whether or not to use checkpoint_frzn for just the first encoder. Default (False) is to use the checkpoint to freeze all encoders. (only relevant for number_of_molecules > 1, where checkpoint model has number_of_molecules = 1)",
    # )

    # TODO: Add in v2.1
    # parser.add_argument(
    #     "--resume-experiment",
    #     action="store_true",
    #     help="Whether to resume the experiment. Loads test results from any folds that have already been completed and skips training those folds.",
    # )
    # parser.add_argument(
    #     "--config-path",
    #     help="Path to a :code:`.json` file containing arguments. Any arguments present in the config file will override arguments specified via the command line or by the defaults.",
    # )
    parser.add_argument(
        "--ensemble-size",
        type=int,
        default=1,
        help="Number of models in ensemble for each splitting of data",
    )

    # TODO: Add in v2.2
    # abt_args = parser.add_argument_group("atom/bond target args")
    # abt_args.add_argument(
    #     "--is-atom-bond-targets",
    #     action="store_true",
    #     help="Whether this is atomic/bond properties prediction.",
    # )
    # abt_args.add_argument(
    #     "--no-adding-bond-types",
    #     action="store_true",
    #     help="Whether the bond types determined by RDKit molecules are added to the output of bond targets. This option is intended to be used with :code:`is_atom_bond_targets`.",
    # )
    # abt_args.add_argument(
    #     "--keeping-atom-map",
    #     action="store_true",
    #     help="Whether RDKit molecules keep the original atom mapping. This option is intended to be used when providing atom-mapped SMILES with :code:`is_atom_bond_targets`.",
    # )
    # abt_args.add_argument(
    #     "--no-shared-atom-bond-ffn",
    #     action="store_true",
    #     help="Whether the FFN weights for atom and bond targets should be independent between tasks.",
    # )
    # abt_args.add_argument(
    #     "--weights-ffn-num-layers",
    #     type=int,
    #     default=2,
    #     help="Number of layers in FFN for determining weights used in constrained targets.",
    # )

    mp_args = parser.add_argument_group("message passing")
    mp_args.add_argument(
        "--message-hidden-dim", type=int, default=300, help="Hidden dimension of the messages"
    )
    mp_args.add_argument(
        "--message-bias", action="store_true", help="Add bias to the message passing layers"
    )
    mp_args.add_argument("--depth", type=int, default=3, help="Number of message passing steps")
    mp_args.add_argument(
        "--undirected",
        action="store_true",
        help="Pass messages on undirected bonds/edges (always sum the two relevant bond vectors)",
    )
    mp_args.add_argument(
        "--dropout",
        type=float,
        default=0.0,
        help="Dropout probability in message passing/FFN layers",
    )
    mp_args.add_argument(
        "--mpn-shared",
        action="store_true",
        help="Whether to use the same message passing neural network for all input molecules (only relevant if ``number_of_molecules`` > 1)",
    )
    mp_args.add_argument(
        "--activation",
        type=uppercase,
        default="RELU",
        choices=list(Activation.keys()),
        help="Activation function in message passing/FFN layers",
    )
    mp_args.add_argument(
        "--aggregation",
        "--agg",
        default="norm",
        action=LookupAction(AggregationRegistry),
        help="Aggregation mode to use in the graph predictor",
    )
    mp_args.add_argument(
        "--aggregation-norm",
        type=float,
        default=100,
        help="Normalization factor by which to divide summed up atomic features for ``norm`` aggregation",
    )
    mp_args.add_argument(
        "--atom-messages", action="store_true", help="Pass messages on atoms rather than bonds."
    )

    # TODO: Add in v2.1
    # mpsolv_args = parser.add_argument_group("message passing with solvent")
    # mpsolv_args.add_argument(
    #     "--reaction-solvent",
    #     action="store_true",
    #     help="Whether to adjust the MPNN layer to take as input a reaction and a molecule, and to encode them with separate MPNNs.",
    # )
    # mpsolv_args.add_argument(
    #     "--bias-solvent",
    #     action="store_true",
    #     help="Whether to add bias to linear layers for solvent MPN if :code:`reaction_solvent` is True.",
    # )
    # mpsolv_args.add_argument(
    #     "--hidden-size-solvent",
    #     type=int,
    #     default=300,
    #     help="Dimensionality of hidden layers in solvent MPN if :code:`reaction_solvent` is True.",
    # )
    # mpsolv_args.add_argument(
    #     "--depth-solvent",
    #     type=int,
    #     default=3,
    #     help="Number of message passing steps for solvent if :code:`reaction_solvent` is True.",
    # )

    ffn_args = parser.add_argument_group("FFN args")
    ffn_args.add_argument(
        "--ffn-hidden-dim", type=int, default=300, help="Hidden dimension in the FFN top model"
    )
    ffn_args.add_argument(  # TODO: the default in v1 was 2. (see weights_ffn_num_layers option) Do we really want the default to now be 1?
        "--ffn-num-layers", type=int, default=1, help="Number of layers in FFN top model"
    )
    # TODO: Decide if we want to implement this in v2
    # ffn_args.add_argument(
    #     "--features-only",
    #     action="store_true",
    #     help="Use only the additional features in an FFN, no graph network.",
    # )

    extra_mpnn_args = parser.add_argument_group("extra MPNN args")
    extra_mpnn_args.add_argument(
        "--batch-norm", action="store_true", help="Turn on batch normalization after aggregation"
    )
    extra_mpnn_args.add_argument(
        "--multiclass-num-classes",
        type=int,
        default=3,
        help="Number of classes when running multiclass classification",
    )
    # TODO: Add in v2.1
    # extra_mpnn_args.add_argument(
    #     "--spectral-activation",
    #     default="exp",
    #     choices=["softplus", "exp"],
    #     help="Indicates which function to use in task_type spectra training to constrain outputs to be positive.",
    # )

    train_data_args = parser.add_argument_group("training input data args")
    train_data_args.add_argument(
        "-w",
        "--weight-column",
        help="Name of the column in the input CSV containing individual data weights",
    )
    train_data_args.add_argument(
        "--target-columns",
        nargs="+",
        help="Name of the columns containing target values (by default, uses all columns except the SMILES column and the ``ignore_columns``)",
    )
    train_data_args.add_argument(
        "--ignore-columns",
        nargs="+",
        help="Name of the columns to ignore when ``target_columns`` is not provided",
    )
    train_data_args.add_argument(
        "--no-cache",
        action="store_true",
        help="Turn off caching the featurized ``MolGraph`` objects at the beginning of training",
    )
    train_data_args.add_argument(
        "--splits-column",
        help="Name of the column in the input CSV file containing 'train', 'val', or 'test' for each row.",
    )
    # TODO: Add in v2.1
    # train_data_args.add_argument(
    #     "--spectra-phase-mask-path",
    #     help="Path to a file containing a phase mask array, used for excluding particular regions in spectra predictions.",
    # )

    train_args = parser.add_argument_group("training args")
    train_args.add_argument(
        "-t",
        "--task-type",
        default="regression",
        action=LookupAction(PredictorRegistry),
        help="Type of dataset (determines the default loss function used during training, defaults to ``regression``)",
    )
    train_args.add_argument(
        "-l",
        "--loss-function",
        action=LookupAction(LossFunctionRegistry),
        help="Loss function to use during training (will use the default loss function for the given task type if not specified)",
    )
    train_args.add_argument(
        "--v-kl",
        "--evidential-regularization",
        type=float,
        default=0.0,
        help="Specify the value used in regularization for evidential loss function. The default value recommended by Soleimany et al. (2021) is 0.2. However, the optimal value is dataset-dependent, so it is recommended that users test different values to find the best value for their model.",
    )

    train_args.add_argument(
        "--eps", type=float, default=1e-8, help="Evidential regularization epsilon"
    )
    train_args.add_argument(
        "--alpha", type=float, default=0.1, help="Target error bounds for quantile interval loss"
    )
    # TODO: Add in v2.1
    # train_args.add_argument(  # TODO: Is threshold the same thing as the spectra target floor? I'm not sure but combined them.
    #     "-T",
    #     "--threshold",
    #     "--spectra-target-floor",
    #     type=float,
    #     default=1e-8,
    #     help="spectral threshold limit. v1 help string: Values in targets for dataset type spectra are replaced with this value, intended to be a small positive number used to enforce positive values.",
    # )
    train_args.add_argument(
        "--metrics",
        "--metric",
        nargs="+",
        action=LookupAction(MetricRegistry),
        help="Specify the evaluation metrics. If unspecified, chemprop will use the following metrics for given dataset types: regression -> ``rmse``, classification -> ``roc``, multiclass -> ``ce`` ('cross entropy'), spectral -> ``sid``. If multiple metrics are provided, the 0-th one will be used for early stopping and checkpointing.",
    )
    train_args.add_argument(
        "--tracking-metric",
        default="val_loss",
        help="The metric to track for early stopping and checkpointing. Defaults to the criterion used during training.",
    )
    train_args.add_argument(
        "--show-individual-scores",
        action="store_true",
        help="Show all scores for individual targets, not just average, at the end.",
    )
    train_args.add_argument(
        "--task-weights",
        nargs="+",
        type=float,
        help="Weights to apply for whole tasks in the loss function",
    )
    train_args.add_argument(
        "--warmup-epochs",
        type=int,
        default=2,
        help="Number of epochs during which learning rate increases linearly from ``init_lr`` to ``max_lr`` (afterwards, learning rate decreases exponentially from ``max_lr`` to ``final_lr``)",
    )

    train_args.add_argument("--init-lr", type=float, default=1e-4, help="Initial learning rate.")
    train_args.add_argument("--max-lr", type=float, default=1e-3, help="Maximum learning rate.")
    train_args.add_argument("--final-lr", type=float, default=1e-4, help="Final learning rate.")
    train_args.add_argument("--epochs", type=int, default=50, help="Number of epochs to train over")
    train_args.add_argument(
        "--patience",
        type=int,
        default=None,
        help="Number of epochs to wait for improvement before early stopping",
    )
    train_args.add_argument(
        "--grad-clip",
        type=float,
        help="Passed directly to the lightning trainer which controls grad clipping (see the ``Trainer()`` docstring for details)",
    )
    train_args.add_argument(
        "--class-balance",
        action="store_true",
        help="Ensures each training batch contains an equal number of positive and negative samples.",
    )

    split_args = parser.add_argument_group("split args")
    split_args.add_argument(
        "--split",
        "--split-type",
        type=uppercase,
        default="RANDOM",
        choices=list(SplitType.keys()),
        help="Method of splitting the data into train/val/test (case insensitive)",
    )
    split_args.add_argument(
        "--split-sizes",
        type=float,
        nargs=3,
        default=[0.8, 0.1, 0.1],
        help="Split proportions for train/validation/test sets",
    )
    split_args.add_argument(
        "--split-key-molecule",
        type=int,
        default=0,
        help="Specify the index of the key molecule used for splitting when multiple molecules are present and a constrained split_type is used (e.g., ``scaffold_balanced`` or ``random_with_repeated_smiles``). Note that this index begins with zero for the first molecule.",
    )
    split_args.add_argument("--num-replicates", type=int, default=1, help="Number of replicates.")
    split_args.add_argument("-k", "--num-folds", help=_CV_REMOVAL_ERROR)
    split_args.add_argument(
        "--save-smiles-splits",
        action="store_true",
        help="Whether to store the SMILES in each train/val/test split",
    )
    split_args.add_argument(
        "--splits-file",
        type=Path,
        help="Path to a JSON file containing pre-defined splits for the input data, formatted as a list of dictionaries with keys ``train``, ``val``, and ``test`` and values as lists of indices or formatted strings (e.g. [0, 1, 2, 4] or '0-2,4')",
    )
    split_args.add_argument(
        "--data-seed",
        type=int,
        default=0,
        help="Specify the random seed to use when splitting data into train/val/test sets. When ``--num-replicates`` > 1, the first replicate uses this seed and all subsequent replicates add 1 to the seed (also used for shuffling data in ``build_dataloader`` when ``shuffle`` is True).",
    )

    parser.add_argument(
        "--pytorch-seed",
        type=int,
        default=None,
        help="Seed for PyTorch randomness (e.g., random initial weights)",
    )

    return parser


def process_train_args(args: Namespace) -> Namespace:
    if args.output_dir is None:
        args.output_dir = CHEMPROP_TRAIN_DIR / args.data_path.stem / NOW

    return args

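# For example (hypothetical path): `chemprop train -i data/tox.csv` with no -o flag
# resolves output_dir above to something like chemprop_training/tox/<timestamp> under
# the current working directory, per the --output-dir help text.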
481 |
+
|
482 |
+
def validate_train_args(args):
|
483 |
+
if args.config_path is None and args.data_path is None:
|
484 |
+
raise ArgumentError(argument=None, message="Data path must be provided for training.")
|
485 |
+
|
486 |
+
if args.num_folds is not None: # i.e. user-specified
|
487 |
+
raise ArgumentError(argument=None, message=_CV_REMOVAL_ERROR)
|
488 |
+
|
489 |
+
if args.data_path.suffix not in [".csv"]:
|
490 |
+
raise ArgumentError(
|
491 |
+
argument=None, message=f"Input data must be a CSV file. Got {args.data_path}"
|
492 |
+
)
|
493 |
+
|
494 |
+
if args.epochs != -1 and args.epochs <= args.warmup_epochs:
|
495 |
+
raise ArgumentError(
|
496 |
+
argument=None,
|
497 |
+
message=f"The number of epochs should be higher than the number of epochs during warmup. Got {args.epochs} epochs and {args.warmup_epochs} warmup epochs",
|
498 |
+
)
|
499 |
+
|
500 |
+
# TODO: model_frzn is deprecated and then remove in v2.2
|
501 |
+
if args.checkpoint is not None and args.model_frzn is not None:
|
502 |
+
raise ArgumentError(
|
503 |
+
argument=None,
|
504 |
+
message="`--checkpoint` and `--model-frzn` cannot be used at the same time.",
|
505 |
+
)
|
506 |
+
|
507 |
+
if "--model-frzn" in sys.argv:
|
508 |
+
logger.warning(
|
509 |
+
"`--model-frzn` is deprecated and will be removed in v2.2. "
|
510 |
+
"Please use `--checkpoint` with `--freeze-encoder` instead."
|
511 |
+
)
|
512 |
+
|
513 |
+
if args.freeze_encoder and args.checkpoint is None:
|
514 |
+
raise ArgumentError(
|
515 |
+
argument=None,
|
516 |
+
message="`--freeze-encoder` can only be used when `--checkpoint` is used.",
|
517 |
+
)
|
518 |
+
|
519 |
+
if args.frzn_ffn_layers > 0:
|
520 |
+
if args.checkpoint is None and args.model_frzn is None:
|
521 |
+
raise ArgumentError(
|
522 |
+
argument=None,
|
523 |
+
message="`--frzn-ffn-layers` can only be used when `--checkpoint` or `--model-frzn` (depreciated in v2.1) is used.",
|
524 |
+
)
|
525 |
+
if args.checkpoint is not None and not args.freeze_encoder:
|
526 |
+
raise ArgumentError(
|
527 |
+
argument=None,
|
528 |
+
message="To freeze the first `n` layers of the FFN via `--frzn-ffn-layers`. The message passing layer should also be frozen with `--freeze-encoder`.",
|
529 |
+
)
|
530 |
+
|
531 |
+
if args.class_balance and args.task_type != "classification":
|
532 |
+
raise ArgumentError(
|
533 |
+
argument=None, message="Class balance is only applicable for classification tasks."
|
534 |
+
)
|
535 |
+
|
536 |
+
valid_tracking_metrics = (
|
537 |
+
args.metrics or [PredictorRegistry[args.task_type]._T_default_metric.alias]
|
538 |
+
) + ["val_loss"]
|
539 |
+
if args.tracking_metric not in valid_tracking_metrics:
|
540 |
+
raise ArgumentError(
|
541 |
+
argument=None,
|
542 |
+
message=f"Tracking metric must be one of {','.join(valid_tracking_metrics)}. "
|
543 |
+
f"Got {args.tracking_metric}. Additional tracking metric options can be specified with "
|
544 |
+
"the `--metrics` flag.",
|
545 |
+
)
|
546 |
+
|
547 |
+
input_cols, target_cols = get_column_names(
|
548 |
+
args.data_path,
|
549 |
+
args.smiles_columns,
|
550 |
+
args.reaction_columns,
|
551 |
+
args.target_columns,
|
552 |
+
args.ignore_columns,
|
553 |
+
args.splits_column,
|
554 |
+
args.weight_column,
|
555 |
+
args.no_header_row,
|
556 |
+
)
|
557 |
+
|
558 |
+
args.input_columns = input_cols
|
559 |
+
args.target_columns = target_cols
|
560 |
+
|
561 |
+
return args
|
562 |
+
|
563 |
+
|
564 |
+
def normalize_inputs(train_dset, val_dset, args):
|
565 |
+
multicomponent = isinstance(train_dset, MulticomponentDataset)
|
566 |
+
num_components = train_dset.n_components if multicomponent else 1
|
567 |
+
|
568 |
+
X_d_transform = None
|
569 |
+
V_f_transforms = [nn.Identity()] * num_components
|
570 |
+
E_f_transforms = [nn.Identity()] * num_components
|
571 |
+
V_d_transforms = [None] * num_components
|
572 |
+
graph_transforms = []
|
573 |
+
|
574 |
+
d_xd = train_dset.d_xd
|
575 |
+
d_vf = train_dset.d_vf
|
576 |
+
d_ef = train_dset.d_ef
|
577 |
+
d_vd = train_dset.d_vd
|
578 |
+
|
579 |
+
if d_xd > 0 and not args.no_descriptor_scaling:
|
580 |
+
scaler = train_dset.normalize_inputs("X_d")
|
581 |
+
val_dset.normalize_inputs("X_d", scaler)
|
582 |
+
|
583 |
+
scaler = scaler if not isinstance(scaler, list) else scaler[0]
|
584 |
+
|
585 |
+
if scaler is not None:
|
586 |
+
logger.info(
|
587 |
+
f"Descriptors: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}"
|
588 |
+
)
|
589 |
+
X_d_transform = ScaleTransform.from_standard_scaler(scaler)
|
590 |
+
|
591 |
+
if d_vf > 0 and not args.no_atom_feature_scaling:
|
592 |
+
scaler = train_dset.normalize_inputs("V_f")
|
593 |
+
val_dset.normalize_inputs("V_f", scaler)
|
594 |
+
|
595 |
+
scalers = [scaler] if not isinstance(scaler, list) else scaler
|
596 |
+
|
597 |
+
for i, scaler in enumerate(scalers):
|
598 |
+
if scaler is None:
|
599 |
+
continue
|
600 |
+
|
601 |
+
logger.info(
|
602 |
+
f"Atom features for mol {i}: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}"
|
603 |
+
)
|
604 |
+
featurizer = (
|
605 |
+
train_dset.datasets[i].featurizer if multicomponent else train_dset.featurizer
|
606 |
+
)
|
607 |
+
V_f_transforms[i] = ScaleTransform.from_standard_scaler(
|
608 |
+
scaler, pad=featurizer.atom_fdim - featurizer.extra_atom_fdim
|
609 |
+
)
|
610 |
+
|
611 |
+
if d_ef > 0 and not args.no_bond_feature_scaling:
|
612 |
+
scaler = train_dset.normalize_inputs("E_f")
|
613 |
+
val_dset.normalize_inputs("E_f", scaler)
|
614 |
+
|
615 |
+
scalers = [scaler] if not isinstance(scaler, list) else scaler
|
616 |
+
|
617 |
+
for i, scaler in enumerate(scalers):
|
618 |
+
if scaler is None:
|
619 |
+
continue
|
620 |
+
|
621 |
+
logger.info(
|
622 |
+
f"Bond features for mol {i}: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}"
|
623 |
+
)
|
624 |
+
featurizer = (
|
625 |
+
train_dset.datasets[i].featurizer if multicomponent else train_dset.featurizer
|
626 |
+
)
|
627 |
+
E_f_transforms[i] = ScaleTransform.from_standard_scaler(
|
628 |
+
scaler, pad=featurizer.bond_fdim - featurizer.extra_bond_fdim
|
629 |
+
)
|
630 |
+
|
631 |
+
for V_f_transform, E_f_transform in zip(V_f_transforms, E_f_transforms):
|
632 |
+
graph_transforms.append(GraphTransform(V_f_transform, E_f_transform))
|
633 |
+
|
634 |
+
if d_vd > 0 and not args.no_atom_descriptor_scaling:
|
635 |
+
scaler = train_dset.normalize_inputs("V_d")
|
636 |
+
val_dset.normalize_inputs("V_d", scaler)
|
637 |
+
|
638 |
+
scalers = [scaler] if not isinstance(scaler, list) else scaler
|
639 |
+
|
640 |
+
for i, scaler in enumerate(scalers):
|
641 |
+
if scaler is None:
|
642 |
+
continue
|
643 |
+
|
644 |
+
logger.info(
|
645 |
+
f"Atom descriptors for mol {i}: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}"
|
646 |
+
)
|
647 |
+
V_d_transforms[i] = ScaleTransform.from_standard_scaler(scaler)
|
648 |
+
|
649 |
+
return X_d_transform, graph_transforms, V_d_transforms
|
650 |
+
|
651 |
+
|
652 |
+
def load_and_use_pretrained_model_scalers(model_path: Path, train_dset, val_dset) -> None:
|
653 |
+
if isinstance(train_dset, MulticomponentDataset):
|
654 |
+
_model = MulticomponentMPNN.load_from_file(model_path)
|
655 |
+
blocks = _model.message_passing.blocks
|
656 |
+
train_dsets = train_dset.datasets
|
657 |
+
val_dsets = val_dset.datasets
|
658 |
+
else:
|
659 |
+
_model = MPNN.load_from_file(model_path)
|
660 |
+
blocks = [_model.message_passing]
|
661 |
+
train_dsets = [train_dset]
|
662 |
+
val_dsets = [val_dset]
|
663 |
+
|
664 |
+
for i in range(len(blocks)):
|
665 |
+
if isinstance(_model.X_d_transform, ScaleTransform):
|
666 |
+
scaler = _model.X_d_transform.to_standard_scaler()
|
667 |
+
train_dsets[i].normalize_inputs("X_d", scaler)
|
668 |
+
val_dsets[i].normalize_inputs("X_d", scaler)
|
669 |
+
|
670 |
+
if isinstance(blocks[i].graph_transform, GraphTransform):
|
671 |
+
if isinstance(blocks[i].graph_transform.V_transform, ScaleTransform):
|
672 |
+
V_anti_pad = (
|
673 |
+
train_dsets[i].featurizer.atom_fdim - train_dsets[i].featurizer.extra_atom_fdim
|
674 |
+
)
|
675 |
+
scaler = blocks[i].graph_transform.V_transform.to_standard_scaler(
|
676 |
+
anti_pad=V_anti_pad
|
677 |
+
)
|
678 |
+
train_dsets[i].normalize_inputs("V_f", scaler)
|
679 |
+
val_dsets[i].normalize_inputs("V_f", scaler)
|
680 |
+
if isinstance(blocks[i].graph_transform.E_transform, ScaleTransform):
|
681 |
+
E_anti_pad = (
|
682 |
+
train_dsets[i].featurizer.bond_fdim - train_dsets[i].featurizer.extra_bond_fdim
|
683 |
+
)
|
684 |
+
scaler = blocks[i].graph_transform.E_transform.to_standard_scaler(
|
685 |
+
anti_pad=E_anti_pad
|
686 |
+
)
|
687 |
+
train_dsets[i].normalize_inputs("E_f", scaler)
|
688 |
+
val_dsets[i].normalize_inputs("E_f", scaler)
|
689 |
+
|
690 |
+
if isinstance(blocks[i].V_d_transform, ScaleTransform):
|
691 |
+
scaler = blocks[i].V_d_transform.to_standard_scaler()
|
692 |
+
train_dsets[i].normalize_inputs("V_d", scaler)
|
693 |
+
val_dsets[i].normalize_inputs("V_d", scaler)
|
694 |
+
|
695 |
+
if isinstance(_model.predictor.output_transform, UnscaleTransform):
|
696 |
+
scaler = _model.predictor.output_transform.to_standard_scaler()
|
697 |
+
train_dset.normalize_targets(scaler)
|
698 |
+
val_dset.normalize_targets(scaler)
|
699 |
+
|
700 |
+
|
701 |
+
def save_config(parser: ArgumentParser, args: Namespace, config_path: Path):
    config_args = deepcopy(args)
    for key, value in vars(config_args).items():
        if isinstance(value, Path):
            setattr(config_args, key, str(value))

    for key in ["atom_features_path", "atom_descriptors_path", "bond_features_path"]:
        if getattr(config_args, key) is not None:
            for index, path in getattr(config_args, key).items():
                getattr(config_args, key)[index] = str(path)

    parser.write_config_file(parsed_namespace=config_args, output_file_paths=[str(config_path)])


def save_smiles_splits(args: Namespace, output_dir, train_dset, val_dset, test_dset):
    match (args.smiles_columns, args.reaction_columns):
        case [_, None]:
            column_labels = deepcopy(args.smiles_columns)
        case [None, _]:
            column_labels = deepcopy(args.reaction_columns)
        case _:
            column_labels = deepcopy(args.smiles_columns)
            column_labels.extend(args.reaction_columns)

    train_smis = train_dset.names
    df_train = pd.DataFrame(train_smis, columns=column_labels)
    df_train.to_csv(output_dir / "train_smiles.csv", index=False)

    val_smis = val_dset.names
    df_val = pd.DataFrame(val_smis, columns=column_labels)
    df_val.to_csv(output_dir / "val_smiles.csv", index=False)

    if test_dset is not None:
        test_smis = test_dset.names
        df_test = pd.DataFrame(test_smis, columns=column_labels)
        df_test.to_csv(output_dir / "test_smiles.csv", index=False)


def build_splits(args, format_kwargs, featurization_kwargs):
    """build the train/val/test splits"""
    logger.info(f"Pulling data from file: {args.data_path}")
    all_data = build_data_from_files(
        args.data_path,
        p_descriptors=args.descriptors_path,
        p_atom_feats=args.atom_features_path,
        p_bond_feats=args.bond_features_path,
        p_atom_descs=args.atom_descriptors_path,
        **format_kwargs,
        **featurization_kwargs,
    )

    if args.splits_column is not None:
        df = pd.read_csv(
            args.data_path, header=None if args.no_header_row else "infer", index_col=False
        )
        grouped = df.groupby(df[args.splits_column].str.lower())
        train_indices = grouped.groups.get("train", pd.Index([])).tolist()
        val_indices = grouped.groups.get("val", pd.Index([])).tolist()
        test_indices = grouped.groups.get("test", pd.Index([])).tolist()
        train_indices, val_indices, test_indices = [train_indices], [val_indices], [test_indices]

    elif args.splits_file is not None:
        with open(args.splits_file, "rb") as json_file:
            split_idxss = json.load(json_file)
        train_indices = [parse_indices(d["train"]) for d in split_idxss]
        val_indices = [parse_indices(d["val"]) for d in split_idxss]
        test_indices = [parse_indices(d["test"]) for d in split_idxss]
        args.num_replicates = len(split_idxss)

    else:
        splitting_data = all_data[args.split_key_molecule]
        if isinstance(splitting_data[0], ReactionDatapoint):
            splitting_mols = [datapoint.rct for datapoint in splitting_data]
        else:
            splitting_mols = [datapoint.mol for datapoint in splitting_data]
        train_indices, val_indices, test_indices = make_split_indices(
            splitting_mols, args.split, args.split_sizes, args.data_seed, args.num_replicates
        )

    train_data, val_data, test_data = split_data_by_indices(
        all_data, train_indices, val_indices, test_indices
    )
    for i_split in range(len(train_data)):
        sizes = [len(train_data[i_split][0]), len(val_data[i_split][0]), len(test_data[i_split][0])]
        logger.info(f"train/val/test split_{i_split} sizes: {sizes}")

    return train_data, val_data, test_data

def summarize(
    target_cols: list[str], task_type: str, dataset: _MolGraphDatasetMixin
) -> tuple[list, list]:
    if task_type in [
        "regression",
        "regression-mve",
        "regression-evidential",
        "regression-quantile",
    ]:
        if isinstance(dataset, MulticomponentDataset):
            y = dataset.datasets[0].Y
        else:
            y = dataset.Y
        y_mean = np.nanmean(y, axis=0)
        y_std = np.nanstd(y, axis=0)
        y_median = np.nanmedian(y, axis=0)
        mean_dev_abs = np.abs(y - y_mean)
        num_targets = np.sum(~np.isnan(y), axis=0)
        frac_1_sigma = np.sum((mean_dev_abs < y_std), axis=0) / num_targets
        frac_2_sigma = np.sum((mean_dev_abs < 2 * y_std), axis=0) / num_targets

        column_headers = ["Statistic"] + [f"Value ({target_cols[i]})" for i in range(y.shape[1])]
        table_rows = [
            ["Num. smiles"] + [f"{len(y)}" for i in range(y.shape[1])],
            ["Num. targets"] + [f"{num_targets[i]}" for i in range(y.shape[1])],
            ["Num. NaN"] + [f"{len(y) - num_targets[i]}" for i in range(y.shape[1])],
            ["Mean"] + [f"{mean:0.3g}" for mean in y_mean],
            ["Std. dev."] + [f"{std:0.3g}" for std in y_std],
            ["Median"] + [f"{median:0.3g}" for median in y_median],
            ["% within 1 s.d."] + [f"{sigma:0.0%}" for sigma in frac_1_sigma],
            ["% within 2 s.d."] + [f"{sigma:0.0%}" for sigma in frac_2_sigma],
        ]
        return (column_headers, table_rows)
    elif task_type in [
        "classification",
        "classification-dirichlet",
        "multiclass",
        "multiclass-dirichlet",
    ]:
        if isinstance(dataset, MulticomponentDataset):
            y = dataset.datasets[0].Y
        else:
            y = dataset.Y

        mask = np.isnan(y)
        classes = np.sort(np.unique(y[~mask]))

        class_counts = np.stack([(classes[:, None] == y[:, i]).sum(1) for i in range(y.shape[1])])
        class_fracs = class_counts / y.shape[0]
        nan_count = np.nansum(mask, axis=0)
        nan_frac = nan_count / y.shape[0]

        column_headers = ["Class"] + [f"Count/Percent {target_cols[i]}" for i in range(y.shape[1])]

        table_rows = [
            [f"{k}"] + [f"{class_counts[j, i]}/{class_fracs[j, i]:0.0%}" for j in range(y.shape[1])]
            for i, k in enumerate(classes)
        ]

        nan_row = ["NaN"] + [f"{nan_count[i]}/{nan_frac[i]:0.0%}" for i in range(y.shape[1])]
        table_rows.append(nan_row)

        total_row = ["Total"] + [f"{y.shape[0]}/{100.00}%" for i in range(y.shape[1])]
        table_rows.append(total_row)

        return (column_headers, table_rows)
    else:
        raise ValueError(f"unsupported task type! Task type '{task_type}' was not recognized.")


def build_table(column_headers: list[str], table_rows: list[str], title: str | None = None) -> str:
    right_justified_columns = [
        Column(header=column_header, justify="right") for column_header in column_headers
    ]
    table = Table(*right_justified_columns, title=title)
    for row in table_rows:
        table.add_row(*row)

    console = Console(record=True, file=StringIO(), width=200)
    console.print(table)
    return console.export_text()


def build_datasets(args, train_data, val_data, test_data):
    """build the train/val/test datasets, where :attr:`test_data` may be None"""
    multicomponent = len(train_data) > 1
    if multicomponent:
        train_dsets = [
            make_dataset(data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
            for data in train_data
        ]
        val_dsets = [
            make_dataset(data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
            for data in val_data
        ]
        train_dset = MulticomponentDataset(train_dsets)
        val_dset = MulticomponentDataset(val_dsets)
        if len(test_data[0]) > 0:
            test_dsets = [
                make_dataset(data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
                for data in test_data
            ]
            test_dset = MulticomponentDataset(test_dsets)
        else:
            test_dset = None
    else:
        train_data = train_data[0]
        val_data = val_data[0]
        test_data = test_data[0]
        train_dset = make_dataset(train_data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
        val_dset = make_dataset(val_data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
        if len(test_data) > 0:
            test_dset = make_dataset(test_data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
        else:
            test_dset = None
    if args.task_type != "spectral":
        for dataset, label in zip(
            [train_dset, val_dset, test_dset], ["Training", "Validation", "Test"]
        ):
            column_headers, table_rows = summarize(args.target_columns, args.task_type, dataset)
            output = build_table(column_headers, table_rows, f"Summary of {label} Data")
            logger.info("\n" + output)

    return train_dset, val_dset, test_dset

def build_model(
    args,
    train_dset: MolGraphDataset | MulticomponentDataset,
    output_transform: UnscaleTransform,
    input_transforms: tuple[ScaleTransform, list[GraphTransform], list[ScaleTransform]],
) -> MPNN:
    mp_cls = AtomMessagePassing if args.atom_messages else BondMessagePassing

    X_d_transform, graph_transforms, V_d_transforms = input_transforms
    if isinstance(train_dset, MulticomponentDataset):
        mp_blocks = [
            mp_cls(
                train_dset.datasets[i].featurizer.atom_fdim,
                train_dset.datasets[i].featurizer.bond_fdim,
                d_h=args.message_hidden_dim,
                d_vd=(
                    train_dset.datasets[i].d_vd
                    if isinstance(train_dset.datasets[i], MoleculeDataset)
                    else 0
                ),
                bias=args.message_bias,
                depth=args.depth,
                undirected=args.undirected,
                dropout=args.dropout,
                activation=args.activation,
                V_d_transform=V_d_transforms[i],
                graph_transform=graph_transforms[i],
            )
            for i in range(train_dset.n_components)
        ]
        if args.mpn_shared:
            if args.reaction_columns is not None and args.smiles_columns is not None:
                raise ArgumentError(
                    argument=None,
                    message="Cannot use shared MPNN with both molecule and reaction data.",
                )

        mp_block = MulticomponentMessagePassing(mp_blocks, train_dset.n_components, args.mpn_shared)
        # NOTE(degraff): this if/else block should be handled by the init of MulticomponentMessagePassing
        # if args.mpn_shared:
        #     mp_block = MulticomponentMessagePassing(mp_blocks[0], n_components, args.mpn_shared)
        # else:
        d_xd = train_dset.datasets[0].d_xd
        n_tasks = train_dset.datasets[0].Y.shape[1]
        mpnn_cls = MulticomponentMPNN
    else:
        mp_block = mp_cls(
            train_dset.featurizer.atom_fdim,
            train_dset.featurizer.bond_fdim,
            d_h=args.message_hidden_dim,
            d_vd=train_dset.d_vd if isinstance(train_dset, MoleculeDataset) else 0,
            bias=args.message_bias,
            depth=args.depth,
            undirected=args.undirected,
            dropout=args.dropout,
            activation=args.activation,
            V_d_transform=V_d_transforms[0],
            graph_transform=graph_transforms[0],
        )
        d_xd = train_dset.d_xd
        n_tasks = train_dset.Y.shape[1]
        mpnn_cls = MPNN

    agg = Factory.build(AggregationRegistry[args.aggregation], norm=args.aggregation_norm)
    predictor_cls = PredictorRegistry[args.task_type]
    if args.loss_function is not None:
        task_weights = torch.ones(n_tasks) if args.task_weights is None else args.task_weights
        criterion = Factory.build(
            LossFunctionRegistry[args.loss_function],
            task_weights=task_weights,
            v_kl=args.v_kl,
            # threshold=args.threshold, TODO: Add in v2.1
            eps=args.eps,
            alpha=args.alpha,
        )
    else:
        criterion = None
    if args.metrics is not None:
        metrics = [Factory.build(MetricRegistry[metric]) for metric in args.metrics]
    else:
        metrics = None

    predictor = Factory.build(
        predictor_cls,
        input_dim=mp_block.output_dim + d_xd,
        n_tasks=n_tasks,
        hidden_dim=args.ffn_hidden_dim,
        n_layers=args.ffn_num_layers,
        dropout=args.dropout,
        activation=args.activation,
        criterion=criterion,
        task_weights=args.task_weights,
        n_classes=args.multiclass_num_classes,
        output_transform=output_transform,
        # spectral_activation=args.spectral_activation, TODO: Add in v2.1
    )

    if args.loss_function is None:
        logger.info(
            f"No loss function was specified! Using class default: {predictor_cls._T_default_criterion}"
        )

    return mpnn_cls(
        mp_block,
        agg,
        predictor,
        args.batch_norm,
        metrics,
        args.warmup_epochs,
        args.init_lr,
        args.max_lr,
        args.final_lr,
        X_d_transform=X_d_transform,
    )

def train_model(
    args, train_loader, val_loader, test_loader, output_dir, output_transform, input_transforms
):
    if args.checkpoint is not None:
        model_paths = find_models(args.checkpoint)
        if args.ensemble_size != len(model_paths):
            logger.warning(
                f"The number of models in the ensemble for each data split is set to {len(model_paths)}."
            )
            args.ensemble_size = len(model_paths)

    for model_idx in range(args.ensemble_size):
        model_output_dir = output_dir / f"model_{model_idx}"
        model_output_dir.mkdir(exist_ok=True, parents=True)

        if args.pytorch_seed is None:
            seed = torch.seed()
            deterministic = False
        else:
            seed = args.pytorch_seed + model_idx
            deterministic = True

        torch.manual_seed(seed)

        if args.checkpoint or args.model_frzn is not None:
            mpnn_cls = (
                MulticomponentMPNN
                if isinstance(train_loader.dataset, MulticomponentDataset)
                else MPNN
            )
            model_path = model_paths[model_idx] if args.checkpoint else args.model_frzn
            model = mpnn_cls.load_from_file(model_path)

            if args.checkpoint:
                model.apply(
                    lambda m: setattr(m, "p", args.dropout)
                    if isinstance(m, torch.nn.Dropout)
                    else None
                )

            # TODO: model_frzn is deprecated and will be removed in v2.2
            if args.model_frzn or args.freeze_encoder:
                model.message_passing.apply(lambda module: module.requires_grad_(False))
                model.message_passing.eval()
                model.bn.apply(lambda module: module.requires_grad_(False))
                model.bn.eval()
                for idx in range(args.frzn_ffn_layers):
                    model.predictor.ffn[idx].requires_grad_(False)
                    model.predictor.ffn[idx + 1].eval()
        else:
            model = build_model(args, train_loader.dataset, output_transform, input_transforms)
        logger.info(model)

        try:
            trainer_logger = TensorBoardLogger(
                model_output_dir, "trainer_logs", default_hp_metric=False
            )
        except ModuleNotFoundError as e:
            logger.warning(
                f"Unable to import TensorBoardLogger, reverting to CSVLogger (original error: {e})."
            )
            trainer_logger = CSVLogger(model_output_dir, "trainer_logs")

        if args.tracking_metric == "val_loss":
            T_tracking_metric = model.criterion.__class__
            tracking_metric = args.tracking_metric
        else:
            T_tracking_metric = MetricRegistry[args.tracking_metric]
            tracking_metric = "val/" + args.tracking_metric

        monitor_mode = "max" if T_tracking_metric.higher_is_better else "min"
        logger.debug(f"Evaluation metric: '{T_tracking_metric.alias}', mode: '{monitor_mode}'")

        if args.remove_checkpoints:
            temp_dir = TemporaryDirectory()
            checkpoint_dir = Path(temp_dir.name)
        else:
            checkpoint_dir = model_output_dir

        checkpoint_filename = (
            f"best-epoch={{epoch}}-{tracking_metric.replace('/', '_')}="
            f"{{{tracking_metric}:.2f}}"
        )
        checkpointing = ModelCheckpoint(
            checkpoint_dir / "checkpoints",
            checkpoint_filename,
            tracking_metric,
            mode=monitor_mode,
            save_last=True,
            auto_insert_metric_name=False,
        )

        if args.epochs != -1:
            patience = args.patience if args.patience is not None else args.epochs
            early_stopping = EarlyStopping(tracking_metric, patience=patience, mode=monitor_mode)
            callbacks = [checkpointing, early_stopping]
        else:
            callbacks = [checkpointing]

        trainer = pl.Trainer(
            logger=trainer_logger,
            enable_progress_bar=True,
            accelerator=args.accelerator,
            devices=args.devices,
            max_epochs=args.epochs,
            callbacks=callbacks,
            gradient_clip_val=args.grad_clip,
            deterministic=deterministic,
        )
        trainer.fit(model, train_loader, val_loader)

        if test_loader is not None:
            if isinstance(trainer.strategy, DDPStrategy):
                torch.distributed.destroy_process_group()

                best_ckpt_path = trainer.checkpoint_callback.best_model_path
                trainer = pl.Trainer(
                    logger=trainer_logger,
                    enable_progress_bar=True,
                    accelerator=args.accelerator,
                    devices=1,
                )
                model = model.load_from_checkpoint(best_ckpt_path)
                predss = trainer.predict(model, dataloaders=test_loader)
            else:
                predss = trainer.predict(dataloaders=test_loader)

            preds = torch.concat(predss, 0)
            if model.predictor.n_targets > 1:
                preds = preds[..., 0]
            preds = preds.numpy()

            evaluate_and_save_predictions(
                preds, test_loader, model.metrics[:-1], model_output_dir, args
            )

        best_model_path = checkpointing.best_model_path
        model = model.__class__.load_from_checkpoint(best_model_path)
        p_model = model_output_dir / "best.pt"
        save_model(p_model, model, args.target_columns)
        logger.info(f"Best model saved to '{p_model}'")

        if args.remove_checkpoints:
            temp_dir.cleanup()

def evaluate_and_save_predictions(preds, test_loader, metrics, model_output_dir, args):
    if isinstance(test_loader.dataset, MulticomponentDataset):
        test_dset = test_loader.dataset.datasets[0]
    else:
        test_dset = test_loader.dataset
    targets = test_dset.Y
    mask = torch.from_numpy(np.isfinite(targets))
    targets = np.nan_to_num(targets, nan=0.0)
    weights = torch.ones(len(test_dset))
    lt_mask = torch.from_numpy(test_dset.lt_mask) if test_dset.lt_mask[0] is not None else None
    gt_mask = torch.from_numpy(test_dset.gt_mask) if test_dset.gt_mask[0] is not None else None

    individual_scores = dict()
    for metric in metrics:
        individual_scores[metric.alias] = []
        for i, col in enumerate(args.target_columns):
            if "multiclass" in args.task_type:
                preds_slice = torch.from_numpy(preds[:, i : i + 1, :])
                targets_slice = torch.from_numpy(targets[:, i : i + 1])
            else:
                preds_slice = torch.from_numpy(preds[:, i : i + 1])
                targets_slice = torch.from_numpy(targets[:, i : i + 1])
            preds_loss = metric(
                preds_slice,
                targets_slice,
                mask[:, i : i + 1],
                weights,
                lt_mask[:, i] if lt_mask is not None else None,
                gt_mask[:, i] if gt_mask is not None else None,
            )
            individual_scores[metric.alias].append(preds_loss)

    logger.info("Test Set results:")
    for metric in metrics:
        avg_loss = sum(individual_scores[metric.alias]) / len(individual_scores[metric.alias])
        logger.info(f"test/{metric.alias}: {avg_loss}")

    if args.show_individual_scores:
        logger.info("Entire Test Set individual results:")
        for metric in metrics:
            for i, col in enumerate(args.target_columns):
                logger.info(f"test/{col}/{metric.alias}: {individual_scores[metric.alias][i]}")

    names = test_loader.dataset.names
    if isinstance(test_loader.dataset, MulticomponentDataset):
        namess = list(zip(*names))
    else:
        namess = [names]

    columns = args.input_columns + args.target_columns
    if "multiclass" in args.task_type:
        columns = columns + [f"{col}_prob" for col in args.target_columns]
        formatted_probability_strings = np.apply_along_axis(
            lambda x: ",".join(map(str, x)), 2, preds
        )
        predicted_class_labels = preds.argmax(axis=-1)
        df_preds = pd.DataFrame(
            list(zip(*namess, *predicted_class_labels.T, *formatted_probability_strings.T)),
            columns=columns,
        )
    else:
        df_preds = pd.DataFrame(list(zip(*namess, *preds.T)), columns=columns)
    df_preds.to_csv(model_output_dir / "test_predictions.csv", index=False)

def main(args):
    format_kwargs = dict(
        no_header_row=args.no_header_row,
        smiles_cols=args.smiles_columns,
        rxn_cols=args.reaction_columns,
        target_cols=args.target_columns,
        ignore_cols=args.ignore_columns,
        splits_col=args.splits_column,
        weight_col=args.weight_column,
        bounded=args.loss_function is not None and "bounded" in args.loss_function,
    )

    featurization_kwargs = dict(
        molecule_featurizers=args.molecule_featurizers,
        keep_h=args.keep_h,
        add_h=args.add_h,
        ignore_chirality=args.ignore_chirality,
    )

    splits = build_splits(args, format_kwargs, featurization_kwargs)

    for replicate_idx, (train_data, val_data, test_data) in enumerate(zip(*splits)):
        if args.num_replicates == 1:
            output_dir = args.output_dir
        else:
            output_dir = args.output_dir / f"replicate_{replicate_idx}"

        output_dir.mkdir(exist_ok=True, parents=True)

        train_dset, val_dset, test_dset = build_datasets(args, train_data, val_data, test_data)

        if args.save_smiles_splits:
            save_smiles_splits(args, output_dir, train_dset, val_dset, test_dset)

        if args.checkpoint or args.model_frzn is not None:
            model_paths = find_models(args.checkpoint)
            if len(model_paths) > 1:
                logger.warning(
                    "Multiple checkpoint files were loaded, but only the scalers from "
                    f"{model_paths[0]} are used. It is assumed that all models provided have the "
                    "same data scalings, meaning they were trained on the same data."
                )
            model_path = model_paths[0] if args.checkpoint else args.model_frzn
            load_and_use_pretrained_model_scalers(model_path, train_dset, val_dset)
            input_transforms = (None, None, None)
            output_transform = None
        else:
            input_transforms = normalize_inputs(train_dset, val_dset, args)

            if "regression" in args.task_type:
                output_scaler = train_dset.normalize_targets()
                val_dset.normalize_targets(output_scaler)
                logger.info(
                    f"Train data: mean = {output_scaler.mean_} | std = {output_scaler.scale_}"
                )
                output_transform = UnscaleTransform.from_standard_scaler(output_scaler)
            else:
                output_transform = None

        if not args.no_cache:
            train_dset.cache = True
            val_dset.cache = True

        train_loader = build_dataloader(
            train_dset,
            args.batch_size,
            args.num_workers,
            class_balance=args.class_balance,
            seed=args.data_seed,
        )
        if args.class_balance:
            logger.debug(
                f"With `--class-balance`, effective train size = {len(train_loader.sampler)}"
            )
        val_loader = build_dataloader(val_dset, args.batch_size, args.num_workers, shuffle=False)
        if test_dset is not None:
            test_loader = build_dataloader(
                test_dset, args.batch_size, args.num_workers, shuffle=False
            )
        else:
            test_loader = None

        train_model(
            args,
            train_loader,
            val_loader,
            test_loader,
            output_dir,
            output_transform,
            input_transforms,
        )


if __name__ == "__main__":
    # TODO: update this old code or remove it.
    parser = ArgumentParser()
    parser = TrainSubcommand.add_args(parser)

    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
    args = parser.parse_args()
    TrainSubcommand.func(args)
chemprop-updated/chemprop/cli/utils/__init__.py
ADDED
@@ -0,0 +1,30 @@
from .actions import LookupAction
from .args import bounded
from .command import Subcommand
from .parsing import (
    build_data_from_files,
    get_column_names,
    make_datapoints,
    make_dataset,
    parse_indices,
)
from .utils import _pop_attr, _pop_attr_d, pop_attr

__all__ = [
    "bounded",
    "LookupAction",
    "Subcommand",
    "build_data_from_files",
    "make_datapoints",
    "make_dataset",
    "get_column_names",
    "parse_indices",
    "actions",
    "args",
    "command",
    "parsing",
    "utils",
    "pop_attr",
    "_pop_attr",
    "_pop_attr_d",
]
chemprop-updated/chemprop/cli/utils/actions.py
ADDED
@@ -0,0 +1,19 @@
from argparse import _StoreAction
from typing import Any, Mapping


def LookupAction(obj: Mapping[str, Any]):
    class LookupAction_(_StoreAction):
        def __init__(self, option_strings, dest, default=None, choices=None, **kwargs):
            if default not in obj.keys() and default is not None:
                raise ValueError(
                    f"Invalid value for arg 'default': '{default}'. "
                    f"Expected one of {tuple(obj.keys())}"
                )

            kwargs["choices"] = choices if choices is not None else obj.keys()
            kwargs["default"] = default

            super().__init__(option_strings, dest, **kwargs)

    return LookupAction_
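For readers unfamiliar with custom argparse actions, a minimal usage sketch (illustrative only; `AggRegistry` below is a stand-in mapping, not part of this diff): `LookupAction` manufactures a `_StoreAction` subclass whose `choices` and `default` are validated against the mapping's keys.

from argparse import ArgumentParser

AggRegistry = {"mean": object(), "sum": object(), "norm": object()}  # hypothetical registry

parser = ArgumentParser()
parser.add_argument("--aggregation", action=LookupAction(AggRegistry), default="mean")
args = parser.parse_args(["--aggregation", "sum"])  # accepted: "sum" is a registry key
# parser.parse_args(["--aggregation", "max"])       # rejected: not among the choices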
chemprop-updated/chemprop/cli/utils/args.py
ADDED
@@ -0,0 +1,34 @@
import functools

__all__ = ["bounded"]


def bounded(lo: float | None = None, hi: float | None = None):
    if lo is None and hi is None:
        raise ValueError("No bounds provided!")

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            x = f(*args, **kwargs)

            if (lo is not None and hi is not None) and not lo <= x <= hi:
                raise ValueError(f"Parsed value outside of range [{lo}, {hi}]! got: {x}")
            if hi is not None and x > hi:
                raise ValueError(f"Parsed value above {hi}! got: {x}")
            if lo is not None and x < lo:
                raise ValueError(f"Parsed value below {lo}! got: {x}")

            return x

        return wrapper

    return decorator


def uppercase(x: str):
    return x.upper()


def lowercase(x: str):
    return x.lower()
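A quick sketch of how `bounded` composes with an argparse `type` callable (illustrative; `fraction` is a hypothetical example, not part of this diff):

@bounded(lo=0.0, hi=1.0)
def fraction(arg: str) -> float:
    # parse then range-check; the decorator raises on out-of-range values
    return float(arg)

fraction("0.25")   # returns 0.25
# fraction("1.5")  # raises ValueError: Parsed value above 1.0! got: 1.5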
chemprop-updated/chemprop/cli/utils/command.py
ADDED
@@ -0,0 +1,24 @@
from abc import ABC, abstractmethod
from argparse import ArgumentParser, Namespace, _SubParsersAction


class Subcommand(ABC):
    COMMAND: str
    HELP: str | None = None

    @classmethod
    def add(cls, subparsers: _SubParsersAction, parents) -> ArgumentParser:
        parser = subparsers.add_parser(cls.COMMAND, help=cls.HELP, parents=parents)
        cls.add_args(parser).set_defaults(func=cls.func)

        return parser

    @classmethod
    @abstractmethod
    def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
        pass

    @classmethod
    @abstractmethod
    def func(cls, args: Namespace):
        pass
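A hedged sketch of a concrete subcommand built on this ABC (the `EchoSubcommand` name and behavior are hypothetical, not part of this diff):

class EchoSubcommand(Subcommand):
    COMMAND = "echo"
    HELP = "print a message and exit"

    @classmethod
    def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
        parser.add_argument("message")
        return parser

    @classmethod
    def func(cls, args: Namespace):
        print(args.message)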
chemprop-updated/chemprop/cli/utils/parsing.py
ADDED
@@ -0,0 +1,457 @@
import logging
from os import PathLike
from typing import Literal, Mapping, Sequence

import numpy as np
import pandas as pd

from chemprop.data.datapoints import MoleculeDatapoint, ReactionDatapoint
from chemprop.data.datasets import MoleculeDataset, ReactionDataset
from chemprop.featurizers.atom import get_multi_hot_atom_featurizer
from chemprop.featurizers.bond import MultiHotBondFeaturizer, RIGRBondFeaturizer
from chemprop.featurizers.molecule import MoleculeFeaturizerRegistry
from chemprop.featurizers.molgraph import (
    CondensedGraphOfReactionFeaturizer,
    SimpleMoleculeMolGraphFeaturizer,
)
from chemprop.utils import make_mol

logger = logging.getLogger(__name__)


def parse_csv(
    path: PathLike,
    smiles_cols: Sequence[str] | None,
    rxn_cols: Sequence[str] | None,
    target_cols: Sequence[str] | None,
    ignore_cols: Sequence[str] | None,
    splits_col: str | None,
    weight_col: str | None,
    bounded: bool = False,
    no_header_row: bool = False,
):
    df = pd.read_csv(path, header=None if no_header_row else "infer", index_col=False)

    if smiles_cols is not None and rxn_cols is not None:
        smiss = df[smiles_cols].T.values.tolist()
        rxnss = df[rxn_cols].T.values.tolist()
        input_cols = [*smiles_cols, *rxn_cols]
    elif smiles_cols is not None and rxn_cols is None:
        smiss = df[smiles_cols].T.values.tolist()
        rxnss = None
        input_cols = smiles_cols
    elif smiles_cols is None and rxn_cols is not None:
        smiss = None
        rxnss = df[rxn_cols].T.values.tolist()
        input_cols = rxn_cols
    else:
        smiss = df.iloc[:, [0]].T.values.tolist()
        rxnss = None
        input_cols = [df.columns[0]]

    if target_cols is None:
        target_cols = list(
            column
            for column in df.columns
            if column
            not in set(  # if splits or weight is None, df.columns will never have None
                input_cols + (ignore_cols or []) + [splits_col] + [weight_col]
            )
        )

    Y = df[target_cols]
    weights = None if weight_col is None else df[weight_col].to_numpy(np.single)

    if bounded:
        Y = Y.astype(str)
        lt_mask = Y.applymap(lambda x: "<" in x).to_numpy()
        gt_mask = Y.applymap(lambda x: ">" in x).to_numpy()
        Y = Y.applymap(lambda x: x.strip("<").strip(">")).to_numpy(np.single)
    else:
        Y = Y.to_numpy(np.single)
        lt_mask = None
        gt_mask = None

    return smiss, rxnss, Y, weights, lt_mask, gt_mask


def get_column_names(
    path: PathLike,
    smiles_cols: Sequence[str] | None,
    rxn_cols: Sequence[str] | None,
    target_cols: Sequence[str] | None,
    ignore_cols: Sequence[str] | None,
    splits_col: str | None,
    weight_col: str | None,
    no_header_row: bool = False,
) -> tuple[list[str], list[str]]:
    df_cols = pd.read_csv(path, index_col=False, nrows=0).columns.tolist()

    if no_header_row:
        return ["SMILES"], ["pred_" + str(i) for i in range(len(df_cols) - 1)]

    input_cols = (smiles_cols or []) + (rxn_cols or [])

    if len(input_cols) == 0:
        input_cols = [df_cols[0]]

    if target_cols is None:
        target_cols = list(
            column
            for column in df_cols
            if column
            not in set(
                input_cols + (ignore_cols or []) + ([splits_col] or []) + ([weight_col] or [])
            )
        )

    return input_cols, target_cols


def make_datapoints(
    smiss: list[list[str]] | None,
    rxnss: list[list[str]] | None,
    Y: np.ndarray,
    weights: np.ndarray | None,
    lt_mask: np.ndarray | None,
    gt_mask: np.ndarray | None,
    X_d: np.ndarray | None,
    V_fss: list[list[np.ndarray] | list[None]] | None,
    E_fss: list[list[np.ndarray] | list[None]] | None,
    V_dss: list[list[np.ndarray] | list[None]] | None,
    molecule_featurizers: list[str] | None,
    keep_h: bool,
    add_h: bool,
    ignore_chirality: bool,
) -> tuple[list[list[MoleculeDatapoint]], list[list[ReactionDatapoint]]]:
    """Make the :class:`MoleculeDatapoint`\s and :class:`ReactionDatapoint`\s for a given
    dataset.

    Parameters
    ----------
    smiss : list[list[str]] | None
        a list of ``j`` lists of ``n`` SMILES strings, where ``j`` is the number of molecules per
        datapoint and ``n`` is the number of datapoints. If ``None``, the corresponding list of
        :class:`MoleculeDatapoint`\s will be empty.
    rxnss : list[list[str]] | None
        a list of ``k`` lists of ``n`` reaction SMILES strings, where ``k`` is the number of
        reactions per datapoint. If ``None``, the corresponding list of :class:`ReactionDatapoint`\s
        will be empty.
    Y : np.ndarray
        the target values of shape ``n x m``, where ``m`` is the number of targets
    weights : np.ndarray | None
        the weights of the datapoints to use in the loss function of shape ``n x m``. If ``None``,
        the weights all default to 1.
    lt_mask : np.ndarray | None
        a boolean mask of shape ``n x m`` indicating whether the targets are less than inequality
        targets. If ``None``, ``lt_mask`` for all datapoints will be ``None``.
    gt_mask : np.ndarray | None
        a boolean mask of shape ``n x m`` indicating whether the targets are greater than inequality
        targets. If ``None``, ``gt_mask`` for all datapoints will be ``None``.
    X_d : np.ndarray | None
        the extra descriptors of shape ``n x p``, where ``p`` is the number of extra descriptors. If
        ``None``, ``x_d`` for all datapoints will be ``None``.
    V_fss : list[list[np.ndarray] | list[None]] | None
        a list of ``j`` lists of ``n`` np.ndarrays each of shape ``v_jn x q_j``, where ``v_jn`` is
        the number of atoms in the j-th molecule of the n-th datapoint and ``q_j`` is the number of
        extra atom features used for the j-th molecules. Any of the ``j`` lists can be a list of
        None values if the corresponding component does not use extra atom features. If ``None``,
        ``V_f`` for all datapoints will be ``None``.
    E_fss : list[list[np.ndarray] | list[None]] | None
        a list of ``j`` lists of ``n`` np.ndarrays each of shape ``e_jn x r_j``, where ``e_jn`` is
        the number of bonds in the j-th molecule of the n-th datapoint and ``r_j`` is the number of
        extra bond features used for the j-th molecules. Any of the ``j`` lists can be a list of
        None values if the corresponding component does not use extra bond features. If ``None``,
        ``E_f`` for all datapoints will be ``None``.
    V_dss : list[list[np.ndarray] | list[None]] | None
        a list of ``j`` lists of ``n`` np.ndarrays each of shape ``v_jn x s_j``, where ``s_j`` is
        the number of extra atom descriptors used for the j-th molecules. Any of the ``j`` lists can
        be a list of None values if the corresponding component does not use extra atom features. If
        ``None``, ``V_d`` for all datapoints will be ``None``.
    molecule_featurizers : list[str] | None
        a list of molecule featurizer names to generate additional molecule features to use as extra
        descriptors. If there are multiple molecules per datapoint, the featurizers will be applied
        to each molecule and concatenated. Note that a :code:`ReactionDatapoint` has two
        RDKit :class:`~rdkit.Chem.Mol` objects, reactant(s) and product(s). Each
        ``molecule_featurizer`` will be applied to both of these objects.
    keep_h : bool
        whether to keep hydrogen atoms
    add_h : bool
        whether to add hydrogen atoms
    ignore_chirality : bool
        whether to ignore chirality information

    Returns
    -------
    list[list[MoleculeDatapoint]]
        a list of ``j`` lists of ``n`` :class:`MoleculeDatapoint`\s
    list[list[ReactionDatapoint]]
        a list of ``k`` lists of ``n`` :class:`ReactionDatapoint`\s
    .. note::
        either ``j`` or ``k`` may be 0, in which case the corresponding list will be empty.

    Raises
    ------
    ValueError
        if both ``smiss`` and ``rxnss`` are ``None``.
        if ``smiss`` and ``rxnss`` are both given and have different lengths.
    """
    if smiss is None and rxnss is None:
        raise ValueError("args 'smiss' and 'rxnss' were both `None`!")
    elif rxnss is None:
        N = len(smiss[0])
        rxnss = []
    elif smiss is None:
        N = len(rxnss[0])
        smiss = []
    elif len(smiss[0]) != len(rxnss[0]):
        raise ValueError(
            f"args 'smiss' and 'rxnss' must have same length! got {len(smiss[0])} and {len(rxnss[0])}"
        )
    else:
        N = len(smiss[0])

    if len(smiss) > 0:
        molss = [[make_mol(smi, keep_h, add_h, ignore_chirality) for smi in smis] for smis in smiss]
    if len(rxnss) > 0:
        rctss = [
            [
                make_mol(
                    f"{rct_smi}.{agt_smi}" if agt_smi else rct_smi, keep_h, add_h, ignore_chirality
                )
                for rct_smi, agt_smi, _ in (rxn.split(">") for rxn in rxns)
            ]
            for rxns in rxnss
        ]
        pdtss = [
            [
                make_mol(pdt_smi, keep_h, add_h, ignore_chirality)
                for _, _, pdt_smi in (rxn.split(">") for rxn in rxns)
            ]
            for rxns in rxnss
        ]

    weights = np.ones(N, dtype=np.single) if weights is None else weights
    gt_mask = [None] * N if gt_mask is None else gt_mask
    lt_mask = [None] * N if lt_mask is None else lt_mask

    n_mols = len(smiss) if smiss else 0
    V_fss = [[None] * N] * n_mols if V_fss is None else V_fss
    E_fss = [[None] * N] * n_mols if E_fss is None else E_fss
    V_dss = [[None] * N] * n_mols if V_dss is None else V_dss

    if X_d is None and molecule_featurizers is None:
        X_d = [None] * N
    elif molecule_featurizers is None:
        pass
    else:
        molecule_featurizers = [MoleculeFeaturizerRegistry[mf]() for mf in molecule_featurizers]

        if len(smiss) > 0:
            mol_descriptors = np.hstack(
                [
                    np.vstack([np.hstack([mf(mol) for mf in molecule_featurizers]) for mol in mols])
                    for mols in molss
                ]
            )
            if X_d is None:
                X_d = mol_descriptors
            else:
                X_d = np.hstack([X_d, mol_descriptors])

        if len(rxnss) > 0:
            rct_pdt_descriptors = np.hstack(
                [
                    np.vstack(
                        [
                            np.hstack(
                                [mf(mol) for mf in molecule_featurizers for mol in (rct, pdt)]
                            )
                            for rct, pdt in zip(rcts, pdts)
                        ]
                    )
                    for rcts, pdts in zip(rctss, pdtss)
                ]
            )
            if X_d is None:
                X_d = rct_pdt_descriptors
            else:
                X_d = np.hstack([X_d, rct_pdt_descriptors])

    mol_data = [
        [
            MoleculeDatapoint(
                mol=molss[mol_idx][i],
                name=smis[i],
                y=Y[i],
                weight=weights[i],
                gt_mask=gt_mask[i],
                lt_mask=lt_mask[i],
                x_d=X_d[i],
                x_phase=None,
                V_f=V_fss[mol_idx][i],
                E_f=E_fss[mol_idx][i],
                V_d=V_dss[mol_idx][i],
            )
            for i in range(N)
        ]
        for mol_idx, smis in enumerate(smiss)
    ]
    rxn_data = [
        [
            ReactionDatapoint(
                rct=rctss[rxn_idx][i],
                pdt=pdtss[rxn_idx][i],
                name=rxns[i],
                y=Y[i],
                weight=weights[i],
                gt_mask=gt_mask[i],
                lt_mask=lt_mask[i],
                x_d=X_d[i],
                x_phase=None,
            )
            for i in range(N)
        ]
        for rxn_idx, rxns in enumerate(rxnss)
    ]

    return mol_data, rxn_data


def build_data_from_files(
    p_data: PathLike,
    no_header_row: bool,
    smiles_cols: Sequence[str] | None,
    rxn_cols: Sequence[str] | None,
    target_cols: Sequence[str] | None,
    ignore_cols: Sequence[str] | None,
    splits_col: str | None,
    weight_col: str | None,
    bounded: bool,
    p_descriptors: PathLike,
    p_atom_feats: dict[int, PathLike],
    p_bond_feats: dict[int, PathLike],
    p_atom_descs: dict[int, PathLike],
    **featurization_kwargs: Mapping,
) -> list[list[MoleculeDatapoint] | list[ReactionDatapoint]]:
    smiss, rxnss, Y, weights, lt_mask, gt_mask = parse_csv(
        p_data,
        smiles_cols,
        rxn_cols,
        target_cols,
        ignore_cols,
        splits_col,
        weight_col,
        bounded,
        no_header_row,
    )
    n_molecules = len(smiss) if smiss is not None else 0
    n_datapoints = len(Y)

    X_ds = load_input_feats_and_descs(p_descriptors, None, None, feat_desc="X_d")
    V_fss = load_input_feats_and_descs(p_atom_feats, n_molecules, n_datapoints, feat_desc="V_f")
    E_fss = load_input_feats_and_descs(p_bond_feats, n_molecules, n_datapoints, feat_desc="E_f")
    V_dss = load_input_feats_and_descs(p_atom_descs, n_molecules, n_datapoints, feat_desc="V_d")

    mol_data, rxn_data = make_datapoints(
        smiss,
        rxnss,
        Y,
        weights,
        lt_mask,
        gt_mask,
        X_ds,
        V_fss,
        E_fss,
        V_dss,
        **featurization_kwargs,
    )

    return mol_data + rxn_data


def load_input_feats_and_descs(
    paths: dict[int, PathLike] | PathLike,
    n_molecules: int | None,
    n_datapoints: int | None,
    feat_desc: str,
):
    if paths is None:
        return None

    match feat_desc:
        case "X_d":
            path = paths
            loaded_feature = np.load(path)
            features = loaded_feature["arr_0"]

        case _:
            for index in paths:
                if index >= n_molecules:
                    raise ValueError(
                        f"For {n_molecules} molecules, atom/bond features/descriptors can only be "
                        f"specified for indices 0-{n_molecules - 1}! Got index {index}."
                    )

            features = []
            for idx in range(n_molecules):
                path = paths.get(idx, None)

                if path is not None:
                    loaded_feature = np.load(path)
                    loaded_feature = [
                        loaded_feature[f"arr_{i}"] for i in range(len(loaded_feature))
                    ]
                else:
                    loaded_feature = [None] * n_datapoints

                features.append(loaded_feature)
    return features


def make_dataset(
    data: Sequence[MoleculeDatapoint] | Sequence[ReactionDatapoint],
    reaction_mode: str,
    multi_hot_atom_featurizer_mode: Literal["V1", "V2", "ORGANIC", "RIGR"] = "V2",
) -> MoleculeDataset | ReactionDataset:
    atom_featurizer = get_multi_hot_atom_featurizer(multi_hot_atom_featurizer_mode)
    match multi_hot_atom_featurizer_mode:
        case "RIGR":
            bond_featurizer = RIGRBondFeaturizer()
        case "V1" | "V2" | "ORGANIC":
            bond_featurizer = MultiHotBondFeaturizer()
        case _:
            raise TypeError(
                f"Unsupported atom featurizer mode '{multi_hot_atom_featurizer_mode=}'!"
            )

    if isinstance(data[0], MoleculeDatapoint):
        extra_atom_fdim = data[0].V_f.shape[1] if data[0].V_f is not None else 0
        extra_bond_fdim = data[0].E_f.shape[1] if data[0].E_f is not None else 0
        featurizer = SimpleMoleculeMolGraphFeaturizer(
            atom_featurizer=atom_featurizer,
            bond_featurizer=bond_featurizer,
            extra_atom_fdim=extra_atom_fdim,
            extra_bond_fdim=extra_bond_fdim,
        )
        return MoleculeDataset(data, featurizer)

    featurizer = CondensedGraphOfReactionFeaturizer(
        mode_=reaction_mode, atom_featurizer=atom_featurizer
    )

    return ReactionDataset(data, featurizer)


def parse_indices(idxs):
    """Parses a string of indices into a list of integers. e.g. '0,1,2-4' -> [0, 1, 2, 3, 4]"""
    if isinstance(idxs, str):
        indices = []
        for idx in idxs.split(","):
            if "-" in idx:
                start, end = map(int, idx.split("-"))
                indices.extend(range(start, end + 1))
            else:
                indices.append(int(idx))
        return indices
    return idxs
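A short sketch of `parse_indices` behavior under the docstring's own example (the assertions are added here for illustration only):

assert parse_indices("0,1,2-4") == [0, 1, 2, 3, 4]  # dash ranges are inclusive
assert parse_indices([5, 6]) == [5, 6]              # non-string inputs pass through unchanged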
chemprop-updated/chemprop/cli/utils/utils.py
ADDED
@@ -0,0 +1,31 @@
from typing import Any

__all__ = ["pop_attr"]


def pop_attr(o: object, attr: str, *args) -> Any | None:
    """like ``pop()`` but for attribute maps"""
    match len(args):
        case 0:
            return _pop_attr(o, attr)
        case 1:
            return _pop_attr_d(o, attr, args[0])
        case _:
            raise TypeError(f"Expected at most 2 arguments! got: {len(args)}")


def _pop_attr(o: object, attr: str) -> Any:
    val = getattr(o, attr)
    delattr(o, attr)

    return val


def _pop_attr_d(o: object, attr: str, default: Any | None = None) -> Any | None:
    try:
        val = getattr(o, attr)
        delattr(o, attr)
    except AttributeError:
        val = default

    return val
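A brief sketch of `pop_attr`'s dict.pop-like contract (illustrative only):

from types import SimpleNamespace

o = SimpleNamespace(x=1)
assert pop_attr(o, "x") == 1            # returns the value and deletes the attribute
assert pop_attr(o, "x", None) is None   # with a default, a missing attribute is not an error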
chemprop-updated/chemprop/conf.py
ADDED
@@ -0,0 +1,6 @@
"""Global configuration variables for chemprop"""

from chemprop.featurizers.molgraph.molecule import SimpleMoleculeMolGraphFeaturizer

DEFAULT_ATOM_FDIM, DEFAULT_BOND_FDIM = SimpleMoleculeMolGraphFeaturizer().shape
DEFAULT_HIDDEN_DIM = 300
chemprop-updated/chemprop/data/__init__.py
ADDED
@@ -0,0 +1,41 @@
from .collate import (
    BatchMolGraph,
    MulticomponentTrainingBatch,
    TrainingBatch,
    collate_batch,
    collate_multicomponent,
)
from .dataloader import build_dataloader
from .datapoints import MoleculeDatapoint, ReactionDatapoint
from .datasets import (
    Datum,
    MoleculeDataset,
    MolGraphDataset,
    MulticomponentDataset,
    ReactionDataset,
)
from .molgraph import MolGraph
from .samplers import ClassBalanceSampler, SeededSampler
from .splitting import SplitType, make_split_indices, split_data_by_indices

__all__ = [
    "BatchMolGraph",
    "TrainingBatch",
    "collate_batch",
    "MulticomponentTrainingBatch",
    "collate_multicomponent",
    "build_dataloader",
    "MoleculeDatapoint",
    "ReactionDatapoint",
    "MoleculeDataset",
    "ReactionDataset",
    "Datum",
    "MulticomponentDataset",
    "MolGraphDataset",
    "MolGraph",
    "ClassBalanceSampler",
    "SeededSampler",
    "SplitType",
    "make_split_indices",
    "split_data_by_indices",
]
chemprop-updated/chemprop/data/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (1.43 kB)
chemprop-updated/chemprop/data/__pycache__/data.cpython-37.pyc
ADDED
Binary file (43.5 kB)
chemprop-updated/chemprop/data/__pycache__/scaffold.cpython-37.pyc
ADDED
Binary file (7.37 kB)
chemprop-updated/chemprop/data/__pycache__/scaler.cpython-37.pyc
ADDED
Binary file (5.95 kB)
chemprop-updated/chemprop/data/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (35.1 kB)
chemprop-updated/chemprop/data/collate.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import InitVar, dataclass, field
|
2 |
+
from typing import Iterable, NamedTuple, Sequence
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from torch import Tensor
|
7 |
+
|
8 |
+
from chemprop.data.datasets import Datum
|
9 |
+
from chemprop.data.molgraph import MolGraph
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass(repr=False, eq=False, slots=True)
|
13 |
+
class BatchMolGraph:
|
14 |
+
"""A :class:`BatchMolGraph` represents a batch of individual :class:`MolGraph`\s.
|
15 |
+
|
16 |
+
It has all the attributes of a ``MolGraph`` with the addition of the ``batch`` attribute. This
|
17 |
+
    class is intended for use with data loading, so it uses :obj:`~torch.Tensor`\s to store data
    """

    mgs: InitVar[Sequence[MolGraph]]
    """A list of individual :class:`MolGraph`\s to be batched together"""
    V: Tensor = field(init=False)
    """the atom feature matrix"""
    E: Tensor = field(init=False)
    """the bond feature matrix"""
    edge_index: Tensor = field(init=False)
    """a tensor of shape ``2 x E`` containing the edges of the graph in COO format"""
    rev_edge_index: Tensor = field(init=False)
    """a tensor of shape ``E`` that maps from an edge index to the index of the source of the
    reverse edge in the ``edge_index`` attribute."""
    batch: Tensor = field(init=False)
    """the index of the parent :class:`MolGraph` in the batched graph"""
    names: list[str] = field(init=False)  # the SMILES strings for the batch

    __size: int = field(init=False)

    def __post_init__(self, mgs: Sequence[MolGraph]):
        self.__size = len(mgs)

        Vs = []
        Es = []
        edge_indexes = []
        rev_edge_indexes = []
        batch_indexes = []
        self.names = []

        num_nodes = 0
        num_edges = 0
        for i, mg in enumerate(mgs):
            Vs.append(mg.V)
            Es.append(mg.E)
            edge_indexes.append(mg.edge_index + num_nodes)
            rev_edge_indexes.append(mg.rev_edge_index + num_edges)
            batch_indexes.append([i] * len(mg.V))
            self.names.append(mg.name)

            num_nodes += mg.V.shape[0]
            num_edges += mg.edge_index.shape[1]

        self.V = torch.from_numpy(np.concatenate(Vs)).float()
        self.E = torch.from_numpy(np.concatenate(Es)).float()
        self.edge_index = torch.from_numpy(np.hstack(edge_indexes)).long()
        self.rev_edge_index = torch.from_numpy(np.concatenate(rev_edge_indexes)).long()
        self.batch = torch.tensor(np.concatenate(batch_indexes)).long()

    def __len__(self) -> int:
        """the number of individual :class:`MolGraph`\s in this batch"""
        return self.__size

    def to(self, device: str | torch.device):
        self.V = self.V.to(device)
        self.E = self.E.to(device)
        self.edge_index = self.edge_index.to(device)
        self.rev_edge_index = self.rev_edge_index.to(device)
        self.batch = self.batch.to(device)


class TrainingBatch(NamedTuple):
    bmg: BatchMolGraph
    V_d: Tensor | None
    X_d: Tensor | None
    Y: Tensor | None
    w: Tensor
    lt_mask: Tensor | None
    gt_mask: Tensor | None


def collate_batch(batch: Iterable[Datum]) -> TrainingBatch:
    mgs, V_ds, x_ds, ys, weights, lt_masks, gt_masks = zip(*batch)

    return TrainingBatch(
        BatchMolGraph(mgs),
        None if V_ds[0] is None else torch.from_numpy(np.concatenate(V_ds)).float(),
        None if x_ds[0] is None else torch.from_numpy(np.array(x_ds)).float(),
        None if ys[0] is None else torch.from_numpy(np.array(ys)).float(),
        torch.tensor(weights, dtype=torch.float).unsqueeze(1),
        None if lt_masks[0] is None else torch.from_numpy(np.array(lt_masks)),
        None if gt_masks[0] is None else torch.from_numpy(np.array(gt_masks)),
    )


class MulticomponentTrainingBatch(NamedTuple):
    bmgs: list[BatchMolGraph]
    V_ds: list[Tensor | None]
    X_d: Tensor | None
    Y: Tensor | None
    w: Tensor
    lt_mask: Tensor | None
    gt_mask: Tensor | None


def collate_multicomponent(batches: Iterable[Iterable[Datum]]) -> MulticomponentTrainingBatch:
    tbs = [collate_batch(batch) for batch in zip(*batches)]

    return MulticomponentTrainingBatch(
        [tb.bmg for tb in tbs],
        [tb.V_d for tb in tbs],
        tbs[0].X_d,
        tbs[0].Y,
        tbs[0].w,
        tbs[0].lt_mask,
        tbs[0].gt_mask,
    )
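As a quick orientation, the following is a minimal sketch of how `collate_batch` stitches individual `Datum` tuples into one `TrainingBatch`; the SMILES strings and target values are illustrative placeholders, not data from this repository.

import numpy as np

from chemprop.data.collate import collate_batch
from chemprop.data.datapoints import MoleculeDatapoint
from chemprop.data.datasets import MoleculeDataset

smis = ["CCO", "c1ccccc1", "CC(=O)O"]
ys = np.array([[0.5], [1.2], [-0.3]])

dset = MoleculeDataset([MoleculeDatapoint.from_smi(s, y) for s, y in zip(smis, ys)])

# collate three Datum tuples into a single TrainingBatch
batch = collate_batch([dset[i] for i in range(len(dset))])

# the atom features of all three molecules are stacked into one matrix, and
# `batch.bmg.batch` records which parent molecule each atom came from
print(batch.bmg.V.shape, batch.bmg.batch.tolist())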
chemprop-updated/chemprop/data/dataloader.py
ADDED
@@ -0,0 +1,71 @@
import logging

from torch.utils.data import DataLoader

from chemprop.data.collate import collate_batch, collate_multicomponent
from chemprop.data.datasets import MoleculeDataset, MulticomponentDataset, ReactionDataset
from chemprop.data.samplers import ClassBalanceSampler, SeededSampler

logger = logging.getLogger(__name__)


def build_dataloader(
    dataset: MoleculeDataset | ReactionDataset | MulticomponentDataset,
    batch_size: int = 64,
    num_workers: int = 0,
    class_balance: bool = False,
    seed: int | None = None,
    shuffle: bool = True,
    **kwargs,
):
    """Return a :obj:`~torch.utils.data.DataLoader` for :class:`MolGraphDataset`\s

    Parameters
    ----------
    dataset : MoleculeDataset | ReactionDataset | MulticomponentDataset
        The dataset containing the molecules or reactions to load.
    batch_size : int, default=64
        the batch size to load.
    num_workers : int, default=0
        the number of workers used to build batches.
    class_balance : bool, default=False
        Whether to perform class balancing (i.e., use an equal number of positive and negative
        molecules). Class balance is only available for single task classification datasets. Set
        shuffle to True in order to get a random subset of the larger class.
    seed : int, default=None
        the random seed to use for shuffling (only used when `shuffle` is `True`).
    shuffle : bool, default=True
        whether to shuffle the data during sampling.
    """

    if class_balance:
        sampler = ClassBalanceSampler(dataset.Y, seed, shuffle)
    elif shuffle and seed is not None:
        sampler = SeededSampler(len(dataset), seed)
    else:
        sampler = None

    if isinstance(dataset, MulticomponentDataset):
        collate_fn = collate_multicomponent
    else:
        collate_fn = collate_batch

    if len(dataset) % batch_size == 1:
        logger.warning(
            f"Dropping last batch of size 1 to avoid issues with batch normalization \
(dataset size = {len(dataset)}, batch_size = {batch_size})"
        )
        drop_last = True
    else:
        drop_last = False

    return DataLoader(
        dataset,
        batch_size,
        sampler is None and shuffle,
        sampler,
        num_workers=num_workers,
        collate_fn=collate_fn,
        drop_last=drop_last,
        **kwargs,
    )
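A minimal sketch of using `build_dataloader` over a small regression dataset; the SMILES strings and random targets below are placeholders.

import numpy as np

from chemprop.data.dataloader import build_dataloader
from chemprop.data.datapoints import MoleculeDatapoint
from chemprop.data.datasets import MoleculeDataset

smis = ["CCO", "CCN", "CCC", "CCCl"]
ys = np.random.rand(4, 1)

dset = MoleculeDataset([MoleculeDatapoint.from_smi(s, y) for s, y in zip(smis, ys)])

# a seeded, shuffled loader; each iteration yields a TrainingBatch
loader = build_dataloader(dset, batch_size=2, shuffle=True, seed=0)
for batch in loader:
    print(batch.bmg.V.shape, batch.Y.shape)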
chemprop-updated/chemprop/data/datapoints.py
ADDED
@@ -0,0 +1,150 @@
from __future__ import annotations

from dataclasses import dataclass

import numpy as np
from rdkit.Chem import AllChem as Chem

from chemprop.featurizers import Featurizer
from chemprop.utils import make_mol

MoleculeFeaturizer = Featurizer[Chem.Mol, np.ndarray]


@dataclass(slots=True)
class _DatapointMixin:
    """A mixin class for molecule-, reaction-, and multicomponent-type data"""

    y: np.ndarray | None = None
    """the targets for the molecule with unknown targets indicated by `nan`s"""
    weight: float = 1.0
    """the weight of this datapoint for the loss calculation."""
    gt_mask: np.ndarray | None = None
    """Indicates whether the targets are an inequality regression target of the form `<x`"""
    lt_mask: np.ndarray | None = None
    """Indicates whether the targets are an inequality regression target of the form `>x`"""
    x_d: np.ndarray | None = None
    """A vector of length ``d_f`` containing additional features (e.g., Morgan fingerprint) that
    will be concatenated to the global representation *after* aggregation"""
    x_phase: list[float] = None
    """A one-hot vector indicating the phase of the data, as used in spectra data."""
    name: str | None = None
    """A string identifier for the datapoint."""

    def __post_init__(self):
        NAN_TOKEN = 0
        if self.x_d is not None:
            self.x_d[np.isnan(self.x_d)] = NAN_TOKEN

    @property
    def t(self) -> int | None:
        return len(self.y) if self.y is not None else None


@dataclass
class _MoleculeDatapointMixin:
    mol: Chem.Mol
    """the molecule associated with this datapoint"""

    @classmethod
    def from_smi(
        cls,
        smi: str,
        *args,
        keep_h: bool = False,
        add_h: bool = False,
        ignore_chirality: bool = False,
        **kwargs,
    ) -> _MoleculeDatapointMixin:
        mol = make_mol(smi, keep_h, add_h, ignore_chirality)

        kwargs["name"] = smi if "name" not in kwargs else kwargs["name"]

        return cls(mol, *args, **kwargs)


@dataclass
class MoleculeDatapoint(_DatapointMixin, _MoleculeDatapointMixin):
    """A :class:`MoleculeDatapoint` contains a single molecule and its associated features and targets."""

    V_f: np.ndarray | None = None
    """a numpy array of shape ``V x d_vf``, where ``V`` is the number of atoms in the molecule, and
    ``d_vf`` is the number of additional features that will be concatenated to atom-level features
    *before* message passing"""
    E_f: np.ndarray | None = None
    """A numpy array of shape ``E x d_ef``, where ``E`` is the number of bonds in the molecule, and
    ``d_ef`` is the number of additional features that will be concatenated to bond-level features
    *before* message passing"""
    V_d: np.ndarray | None = None
    """A numpy array of shape ``V x d_vd``, where ``V`` is the number of atoms in the molecule, and
    ``d_vd`` is the number of additional descriptors that will be concatenated to atom-level
    descriptors *after* message passing"""

    def __post_init__(self):
        NAN_TOKEN = 0
        if self.V_f is not None:
            self.V_f[np.isnan(self.V_f)] = NAN_TOKEN
        if self.E_f is not None:
            self.E_f[np.isnan(self.E_f)] = NAN_TOKEN
        if self.V_d is not None:
            self.V_d[np.isnan(self.V_d)] = NAN_TOKEN

        super().__post_init__()

    def __len__(self) -> int:
        return 1


@dataclass
class _ReactionDatapointMixin:
    rct: Chem.Mol
    """the reactant associated with this datapoint"""
    pdt: Chem.Mol
    """the product associated with this datapoint"""

    @classmethod
    def from_smi(
        cls,
        rxn_or_smis: str | tuple[str, str],
        *args,
        keep_h: bool = False,
        add_h: bool = False,
        ignore_chirality: bool = False,
        **kwargs,
    ) -> _ReactionDatapointMixin:
        match rxn_or_smis:
            case str():
                rct_smi, agt_smi, pdt_smi = rxn_or_smis.split(">")
                rct_smi = f"{rct_smi}.{agt_smi}" if agt_smi else rct_smi
                name = rxn_or_smis
            case tuple():
                rct_smi, pdt_smi = rxn_or_smis
                name = ">>".join(rxn_or_smis)
            case _:
                raise TypeError(
                    "Must provide either a reaction SMARTS string or a tuple of reactant and"
                    " product SMILES strings!"
                )

        rct = make_mol(rct_smi, keep_h, add_h, ignore_chirality)
        pdt = make_mol(pdt_smi, keep_h, add_h, ignore_chirality)

        kwargs["name"] = name if "name" not in kwargs else kwargs["name"]

        return cls(rct, pdt, *args, **kwargs)


@dataclass
class ReactionDatapoint(_DatapointMixin, _ReactionDatapointMixin):
    """A :class:`ReactionDatapoint` contains a single reaction and its associated features and targets."""

    def __post_init__(self):
        if self.rct is None:
            raise ValueError("Reactant cannot be `None`!")
        if self.pdt is None:
            raise ValueError("Product cannot be `None`!")

        return super().__post_init__()

    def __len__(self) -> int:
        return 2
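A short sketch of the two `from_smi` entry points defined above; the SMILES strings and target values are illustrative placeholders.

import numpy as np

from chemprop.data.datapoints import MoleculeDatapoint, ReactionDatapoint

# a molecule datapoint: `name` defaults to the input SMILES string
mol_dp = MoleculeDatapoint.from_smi("CCO", np.array([1.0]))

# a reaction datapoint from a single reaction SMILES (reactant>agent>product);
# any agent present is merged into the reactant
rxn_dp = ReactionDatapoint.from_smi("CCO>>CC=O", y=np.array([0.0]))

# ...or from an explicit (reactant, product) tuple
rxn_dp2 = ReactionDatapoint.from_smi(("CCO", "CC=O"), y=np.array([0.0]))
print(mol_dp.name, rxn_dp.name, len(rxn_dp2))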
chemprop-updated/chemprop/data/datasets.py
ADDED
@@ -0,0 +1,475 @@
from dataclasses import dataclass, field
from functools import cached_property
from typing import NamedTuple, TypeAlias

import numpy as np
from numpy.typing import ArrayLike
from rdkit import Chem
from rdkit.Chem import Mol
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset

from chemprop.data.datapoints import MoleculeDatapoint, ReactionDatapoint
from chemprop.data.molgraph import MolGraph
from chemprop.featurizers.base import Featurizer
from chemprop.featurizers.molgraph import CGRFeaturizer, SimpleMoleculeMolGraphFeaturizer
from chemprop.featurizers.molgraph.cache import MolGraphCache, MolGraphCacheOnTheFly
from chemprop.types import Rxn


class Datum(NamedTuple):
    """a singular training data point"""

    mg: MolGraph
    V_d: np.ndarray | None
    x_d: np.ndarray | None
    y: np.ndarray | None
    weight: float
    lt_mask: np.ndarray | None
    gt_mask: np.ndarray | None


MolGraphDataset: TypeAlias = Dataset[Datum]


class _MolGraphDatasetMixin:
    def __len__(self) -> int:
        return len(self.data)

    @cached_property
    def _Y(self) -> np.ndarray:
        """the raw targets of the dataset"""
        return np.array([d.y for d in self.data], float)

    @property
    def Y(self) -> np.ndarray:
        """the (scaled) targets of the dataset"""
        return self.__Y

    @Y.setter
    def Y(self, Y: ArrayLike):
        self._validate_attribute(Y, "targets")

        self.__Y = np.array(Y, float)

    @cached_property
    def _X_d(self) -> np.ndarray:
        """the raw extra descriptors of the dataset"""
        return np.array([d.x_d for d in self.data])

    @property
    def X_d(self) -> np.ndarray:
        """the (scaled) extra descriptors of the dataset"""
        return self.__X_d

    @X_d.setter
    def X_d(self, X_d: ArrayLike):
        self._validate_attribute(X_d, "extra descriptors")

        self.__X_d = np.array(X_d)

    @property
    def weights(self) -> np.ndarray:
        return np.array([d.weight for d in self.data])

    @property
    def gt_mask(self) -> np.ndarray:
        return np.array([d.gt_mask for d in self.data])

    @property
    def lt_mask(self) -> np.ndarray:
        return np.array([d.lt_mask for d in self.data])

    @property
    def t(self) -> int | None:
        return self.data[0].t if len(self.data) > 0 else None

    @property
    def d_xd(self) -> int:
        """the extra molecule descriptor dimension, if any"""
        return 0 if self.X_d[0] is None else self.X_d.shape[1]

    @property
    def names(self) -> list[str]:
        return [d.name for d in self.data]

    def normalize_targets(self, scaler: StandardScaler | None = None) -> StandardScaler:
        """Normalizes the targets of this dataset using a :obj:`StandardScaler`

        The :obj:`StandardScaler` subtracts the mean and divides by the standard deviation for
        each task independently. NOTE: This should only be used for regression datasets.

        Returns
        -------
        StandardScaler
            a scaler fit to the targets.
        """

        if scaler is None:
            scaler = StandardScaler().fit(self._Y)

        self.Y = scaler.transform(self._Y)

        return scaler

    def normalize_inputs(
        self, key: str = "X_d", scaler: StandardScaler | None = None
    ) -> StandardScaler:
        VALID_KEYS = {"X_d"}
        if key not in VALID_KEYS:
            raise ValueError(f"Invalid feature key! got: {key}. expected one of: {VALID_KEYS}")

        X = self.X_d if self.X_d[0] is not None else None

        if X is None:
            return scaler

        if scaler is None:
            scaler = StandardScaler().fit(X)

        self.X_d = scaler.transform(X)

        return scaler

    def reset(self):
        """Reset the atom and bond features; atom and extra descriptors; and targets of each
        datapoint to their initial, unnormalized values."""
        self.__Y = self._Y
        self.__X_d = self._X_d

    def _validate_attribute(self, X: np.ndarray, label: str):
        if not len(self.data) == len(X):
            raise ValueError(
                f"number of molecules ({len(self.data)}) and {label} ({len(X)}) "
                "must have same length!"
            )


@dataclass
class MoleculeDataset(_MolGraphDatasetMixin, MolGraphDataset):
    """A :class:`MoleculeDataset` composed of :class:`MoleculeDatapoint`\s

    A :class:`MoleculeDataset` produces featurized data for input to a
    :class:`MPNN` model. Typically, data featurization is performed on-the-fly
    and parallelized across multiple workers via the
    :class:`~torch.utils.data.DataLoader` class. However, for small datasets, it
    may be more efficient to featurize the data in advance and cache the results.
    This can be done by setting ``MoleculeDataset.cache=True``.

    Parameters
    ----------
    data : Iterable[MoleculeDatapoint]
        the data from which to create a dataset
    featurizer : MoleculeFeaturizer
        the featurizer with which to generate MolGraphs of the molecules
    """

    data: list[MoleculeDatapoint]
    featurizer: Featurizer[Mol, MolGraph] = field(default_factory=SimpleMoleculeMolGraphFeaturizer)

    def __post_init__(self):
        if self.data is None:
            raise ValueError("Data cannot be None!")

        self.reset()
        self.cache = False

    def __getitem__(self, idx: int) -> Datum:
        d = self.data[idx]
        mg = self.mg_cache[idx]

        # assign the SMILES string to the MolGraph
        mg_with_name = MolGraph(
            V=mg.V,
            E=mg.E,
            edge_index=mg.edge_index,
            rev_edge_index=mg.rev_edge_index,
            name=d.name,  # the SMILES string
        )

        return Datum(
            mg=mg_with_name,  # use the updated MolGraph
            V_d=self.V_ds[idx],
            x_d=self.X_d[idx],
            y=self.Y[idx],
            weight=d.weight,
            lt_mask=d.lt_mask,
            gt_mask=d.gt_mask,
        )

    @property
    def cache(self) -> bool:
        return self.__cache

    @cache.setter
    def cache(self, cache: bool = False):
        self.__cache = cache
        self._init_cache()

    def _init_cache(self):
        """initialize the cache"""
        self.mg_cache = (MolGraphCache if self.cache else MolGraphCacheOnTheFly)(
            self.mols, self.V_fs, self.E_fs, self.featurizer
        )

    @property
    def smiles(self) -> list[str]:
        """the SMILES strings associated with the dataset"""
        return [Chem.MolToSmiles(d.mol) for d in self.data]

    @property
    def mols(self) -> list[Chem.Mol]:
        """the molecules associated with the dataset"""
        return [d.mol for d in self.data]

    @property
    def _V_fs(self) -> list[np.ndarray]:
        """the raw atom features of the dataset"""
        return [d.V_f for d in self.data]

    @property
    def V_fs(self) -> list[np.ndarray]:
        """the (scaled) atom features of the dataset"""
        return self.__V_fs

    @V_fs.setter
    def V_fs(self, V_fs: list[np.ndarray]):
        self._validate_attribute(V_fs, "atom features")

        self.__V_fs = V_fs
        self._init_cache()

    @property
    def _E_fs(self) -> list[np.ndarray]:
        """the raw bond features of the dataset"""
        return [d.E_f for d in self.data]

    @property
    def E_fs(self) -> list[np.ndarray]:
        """the (scaled) bond features of the dataset"""
        return self.__E_fs

    @E_fs.setter
    def E_fs(self, E_fs: list[np.ndarray]):
        self._validate_attribute(E_fs, "bond features")

        self.__E_fs = E_fs
        self._init_cache()

    @property
    def _V_ds(self) -> list[np.ndarray]:
        """the raw atom descriptors of the dataset"""
        return [d.V_d for d in self.data]

    @property
    def V_ds(self) -> list[np.ndarray]:
        """the (scaled) atom descriptors of the dataset"""
        return self.__V_ds

    @V_ds.setter
    def V_ds(self, V_ds: list[np.ndarray]):
        self._validate_attribute(V_ds, "atom descriptors")

        self.__V_ds = V_ds

    @property
    def d_vf(self) -> int:
        """the extra atom feature dimension, if any"""
        return 0 if self.V_fs[0] is None else self.V_fs[0].shape[1]

    @property
    def d_ef(self) -> int:
        """the extra bond feature dimension, if any"""
        return 0 if self.E_fs[0] is None else self.E_fs[0].shape[1]

    @property
    def d_vd(self) -> int:
        """the extra atom descriptor dimension, if any"""
        return 0 if self.V_ds[0] is None else self.V_ds[0].shape[1]

    def normalize_inputs(
        self, key: str = "X_d", scaler: StandardScaler | None = None
    ) -> StandardScaler:
        VALID_KEYS = {"X_d", "V_f", "E_f", "V_d"}

        match key:
            case "X_d":
                X = None if self.d_xd == 0 else self.X_d
            case "V_f":
                X = None if self.d_vf == 0 else np.concatenate(self.V_fs, axis=0)
            case "E_f":
                X = None if self.d_ef == 0 else np.concatenate(self.E_fs, axis=0)
            case "V_d":
                X = None if self.d_vd == 0 else np.concatenate(self.V_ds, axis=0)
            case _:
                raise ValueError(f"Invalid feature key! got: {key}. expected one of: {VALID_KEYS}")

        if X is None:
            return scaler

        if scaler is None:
            scaler = StandardScaler().fit(X)

        match key:
            case "X_d":
                self.X_d = scaler.transform(X)
            case "V_f":
                self.V_fs = [scaler.transform(V_f) if V_f.size > 0 else V_f for V_f in self.V_fs]
            case "E_f":
                self.E_fs = [scaler.transform(E_f) if E_f.size > 0 else E_f for E_f in self.E_fs]
            case "V_d":
                self.V_ds = [scaler.transform(V_d) if V_d.size > 0 else V_d for V_d in self.V_ds]
            case _:
                raise RuntimeError("unreachable code reached!")

        return scaler

    def reset(self):
        """Reset the atom and bond features; atom and extra descriptors; and targets of each
        datapoint to their initial, unnormalized values."""
        super().reset()
        self.__V_fs = self._V_fs
        self.__E_fs = self._E_fs
        self.__V_ds = self._V_ds


@dataclass
class ReactionDataset(_MolGraphDatasetMixin, MolGraphDataset):
    """A :class:`ReactionDataset` composed of :class:`ReactionDatapoint`\s

    .. note::
        The featurized data provided by this class may be cached, similar to a
        :class:`MoleculeDataset`. To enable the cache, set
        ``ReactionDataset.cache=True``.
    """

    data: list[ReactionDatapoint]
    """the dataset from which to load"""
    featurizer: Featurizer[Rxn, MolGraph] = field(default_factory=CGRFeaturizer)
    """the featurizer with which to generate MolGraphs of the input"""

    def __post_init__(self):
        if self.data is None:
            raise ValueError("Data cannot be None!")

        self.reset()
        self.cache = False

    @property
    def cache(self) -> bool:
        return self.__cache

    @cache.setter
    def cache(self, cache: bool = False):
        self.__cache = cache
        self.mg_cache = (MolGraphCache if cache else MolGraphCacheOnTheFly)(
            self.mols, [None] * len(self), [None] * len(self), self.featurizer
        )

    def __getitem__(self, idx: int) -> Datum:
        d = self.data[idx]
        mg = self.mg_cache[idx]

        return Datum(mg, None, self.X_d[idx], self.Y[idx], d.weight, d.lt_mask, d.gt_mask)

    @property
    def smiles(self) -> list[tuple]:
        return [(Chem.MolToSmiles(d.rct), Chem.MolToSmiles(d.pdt)) for d in self.data]

    @property
    def mols(self) -> list[Rxn]:
        return [(d.rct, d.pdt) for d in self.data]

    @property
    def d_vf(self) -> int:
        return 0

    @property
    def d_ef(self) -> int:
        return 0

    @property
    def d_vd(self) -> int:
        return 0


@dataclass(repr=False, eq=False)
class MulticomponentDataset(_MolGraphDatasetMixin, Dataset):
    """A :class:`MulticomponentDataset` is a :class:`Dataset` composed of parallel
    :class:`MoleculeDataset`\s and :class:`ReactionDataset`\s"""

    datasets: list[MoleculeDataset | ReactionDataset]
    """the parallel datasets"""

    def __post_init__(self):
        sizes = [len(dset) for dset in self.datasets]
        if not all(sizes[0] == size for size in sizes[1:]):
            raise ValueError(f"Datasets must all have the same length! got: {sizes}")

    def __len__(self) -> int:
        return len(self.datasets[0])

    @property
    def n_components(self) -> int:
        return len(self.datasets)

    def __getitem__(self, idx: int) -> list[Datum]:
        return [dset[idx] for dset in self.datasets]

    @property
    def smiles(self) -> list[list[str]]:
        return list(zip(*[dset.smiles for dset in self.datasets]))

    @property
    def names(self) -> list[list[str]]:
        return list(zip(*[dset.names for dset in self.datasets]))

    @property
    def mols(self) -> list[list[Chem.Mol]]:
        return list(zip(*[dset.mols for dset in self.datasets]))

    def normalize_targets(self, scaler: StandardScaler | None = None) -> StandardScaler:
        return self.datasets[0].normalize_targets(scaler)

    def normalize_inputs(
        self, key: str = "X_d", scaler: list[StandardScaler] | None = None
    ) -> list[StandardScaler]:
        RXN_VALID_KEYS = {"X_d"}
        match scaler:
            case None:
                return [
                    dset.normalize_inputs(key)
                    if isinstance(dset, MoleculeDataset) or key in RXN_VALID_KEYS
                    else None
                    for dset in self.datasets
                ]
            case _:
                assert len(scaler) == len(
                    self.datasets
                ), "Number of scalers must match number of datasets!"

                return [
                    dset.normalize_inputs(key, s)
                    if isinstance(dset, MoleculeDataset) or key in RXN_VALID_KEYS
                    else None
                    for dset, s in zip(self.datasets, scaler)
                ]

    def reset(self):
        return [dset.reset() for dset in self.datasets]

    @property
    def d_xd(self) -> int:
        return self.datasets[0].d_xd

    @property
    def d_vf(self) -> int:
        return sum(dset.d_vf for dset in self.datasets)

    @property
    def d_ef(self) -> int:
        return sum(dset.d_ef for dset in self.datasets)

    @property
    def d_vd(self) -> int:
        return sum(dset.d_vd for dset in self.datasets)
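A minimal sketch of target normalization and reset on a `MoleculeDataset`; the data below are placeholders chosen so the effect is easy to see.

import numpy as np

from chemprop.data.datapoints import MoleculeDatapoint
from chemprop.data.datasets import MoleculeDataset

smis = ["CCO", "CCCO", "CCCCO"]
ys = np.array([[1.0], [2.0], [3.0]])
dset = MoleculeDataset([MoleculeDatapoint.from_smi(s, y) for s, y in zip(smis, ys)])

scaler = dset.normalize_targets()     # fit and apply a StandardScaler
print(dset.Y.mean(0), dset.Y.std(0))  # approximately 0 and 1 per task

dset.reset()                          # restore the raw, unscaled targets
print(dset.Y.ravel())                 # back to [1., 2., 3.]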
chemprop-updated/chemprop/data/molgraph.py
ADDED
@@ -0,0 +1,17 @@
from typing import NamedTuple

import numpy as np


class MolGraph(NamedTuple):
    """A :class:`MolGraph` represents the graph featurization of a molecule."""

    V: np.ndarray
    """an array of shape ``V x d_v`` containing the atom features of the molecule"""
    E: np.ndarray
    """an array of shape ``E x d_e`` containing the bond features of the molecule"""
    edge_index: np.ndarray
    """an array of shape ``2 x E`` containing the edges of the graph in COO format"""
    rev_edge_index: np.ndarray
    """an array of shape ``E`` that maps from an edge index to the index of the source of the reverse edge in the :attr:`edge_index` attribute."""
    name: str | None = None  # the SMILES string, as an optional attribute
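To make the directed-edge convention concrete, here is a hand-built toy `MolGraph` for a two-atom molecule; the feature values are arbitrary placeholders.

import numpy as np

from chemprop.data.molgraph import MolGraph

V = np.ones((2, 4))               # 2 atoms, 4 atom features each
E = np.ones((2, 3))               # 2 directed edges (one bond), 3 bond features
edge_index = np.array([[0, 1],    # edge 0: 0 -> 1, edge 1: 1 -> 0
                       [1, 0]])
rev_edge_index = np.array([1, 0])  # each edge's reverse is the other one

mg = MolGraph(V, E, edge_index, rev_edge_index, name="toy")
print(mg.name, mg.V.shape, mg.edge_index.shape)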
chemprop-updated/chemprop/data/samplers.py
ADDED
@@ -0,0 +1,66 @@
from itertools import chain
from typing import Iterator, Optional

import numpy as np
from torch.utils.data import Sampler


class SeededSampler(Sampler):
    """A :class:`SeededSampler` is a class for iterating through a dataset in a randomly seeded
    fashion"""

    def __init__(self, N: int, seed: int):
        if seed is None:
            raise ValueError("arg 'seed' was `None`! A SeededSampler must be seeded!")

        self.idxs = np.arange(N)
        self.rg = np.random.default_rng(seed)

    def __iter__(self) -> Iterator[int]:
        """an iterator over indices to sample."""
        self.rg.shuffle(self.idxs)

        return iter(self.idxs)

    def __len__(self) -> int:
        """the number of indices that will be sampled."""
        return len(self.idxs)


class ClassBalanceSampler(Sampler):
    """A :class:`ClassBalanceSampler` samples data from a :class:`MolGraphDataset` such that
    positive and negative classes are equally sampled

    Parameters
    ----------
    Y : np.ndarray
        the target matrix from which to determine the positive and negative classes
    seed : int
        the random seed to use for shuffling (only used when `shuffle` is `True`)
    shuffle : bool, default=False
        whether to shuffle the data during sampling
    """

    def __init__(self, Y: np.ndarray, seed: Optional[int] = None, shuffle: bool = False):
        self.shuffle = shuffle
        self.rg = np.random.default_rng(seed)

        idxs = np.arange(len(Y))
        actives = Y.any(1)

        self.pos_idxs = idxs[actives]
        self.neg_idxs = idxs[~actives]

        self.length = 2 * min(len(self.pos_idxs), len(self.neg_idxs))

    def __iter__(self) -> Iterator[int]:
        """an iterator over indices to sample."""
        if self.shuffle:
            self.rg.shuffle(self.pos_idxs)
            self.rg.shuffle(self.neg_idxs)

        return chain(*zip(self.pos_idxs, self.neg_idxs))

    def __len__(self) -> int:
        """the number of indices that will be sampled."""
        return self.length
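A small sketch of how `ClassBalanceSampler` interleaves the two classes; the target matrix is a placeholder with 2 positives and 4 negatives.

import numpy as np

from chemprop.data.samplers import ClassBalanceSampler

Y = np.array([[0], [1], [0], [0], [1], [0]])
sampler = ClassBalanceSampler(Y, seed=0, shuffle=True)

# yields 2 * min(n_pos, n_neg) = 4 indices, alternating positive/negative
print(len(sampler), list(sampler))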
chemprop-updated/chemprop/data/splitting.py
ADDED
@@ -0,0 +1,225 @@
from collections.abc import Iterable, Sequence
import copy
from enum import auto
import logging

from astartes import train_test_split, train_val_test_split
from astartes.molecules import train_test_split_molecules, train_val_test_split_molecules
import numpy as np
from rdkit import Chem

from chemprop.data.datapoints import MoleculeDatapoint, ReactionDatapoint
from chemprop.utils.utils import EnumMapping

logger = logging.getLogger(__name__)

Datapoints = Sequence[MoleculeDatapoint] | Sequence[ReactionDatapoint]
MulticomponentDatapoints = Sequence[Datapoints]


class SplitType(EnumMapping):
    SCAFFOLD_BALANCED = auto()
    RANDOM_WITH_REPEATED_SMILES = auto()
    RANDOM = auto()
    KENNARD_STONE = auto()
    KMEANS = auto()


def make_split_indices(
    mols: Sequence[Chem.Mol],
    split: SplitType | str = "random",
    sizes: tuple[float, float, float] = (0.8, 0.1, 0.1),
    seed: int = 0,
    num_replicates: int = 1,
    num_folds: None = None,
) -> tuple[list[list[int]], ...]:
    """Splits data into training, validation, and test splits.

    Parameters
    ----------
    mols : Sequence[Chem.Mol]
        Sequence of RDKit molecules to use for structure-based splitting
    split : SplitType | str, optional
        Split type, one of :class:`~chemprop.data.utils.SplitType`, by default "random"
    sizes : tuple[float, float, float], optional
        3-tuple with the proportions of data in the train, validation, and test sets, by default
        (0.8, 0.1, 0.1). Set the middle value to 0 for a two-way split.
    seed : int, optional
        The random seed passed to astartes, by default 0
    num_replicates : int, optional
        Number of replicates, by default 1
    num_folds : None, optional
        This argument was removed in v2.1 - use `num_replicates` instead.

    Returns
    -------
    tuple[list[list[int]], ...]
        2- or 3-member tuple containing ``num_replicates``-length lists of training, validation,
        and testing indexes.

        .. important::
            Validation may or may not be present

    Raises
    ------
    ValueError
        Requested split sizes tuple not of length 3
    ValueError
        Unsupported split method requested
    """
    if num_folds is not None:
        raise RuntimeError("This argument was removed in v2.1 - use `num_replicates` instead.")
    if num_replicates == 1:
        logger.warning(
            "The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)"
        )
    if (num_splits := len(sizes)) != 3:
        raise ValueError(
            f"Specify sizes for train, validation, and test (got {num_splits} values)."
        )
    # typically include a validation set
    include_val = True
    split_fun = train_val_test_split
    mol_split_fun = train_val_test_split_molecules
    # default sampling arguments for the astartes sampler
    astartes_kwargs = dict(
        train_size=sizes[0], test_size=sizes[2], return_indices=True, random_state=seed
    )
    # if there is no validation set, reassign the splitting functions
    if sizes[1] == 0.0:
        include_val = False
        split_fun = train_test_split
        mol_split_fun = train_test_split_molecules
    else:
        astartes_kwargs["val_size"] = sizes[1]

    n_datapoints = len(mols)
    train_replicates, val_replicates, test_replicates = [], [], []
    for _ in range(num_replicates):
        train, val, test = None, None, None
        match SplitType.get(split):
            case SplitType.SCAFFOLD_BALANCED:
                mols_without_atommaps = []
                for mol in mols:
                    copied_mol = copy.deepcopy(mol)
                    for atom in copied_mol.GetAtoms():
                        atom.SetAtomMapNum(0)
                    mols_without_atommaps.append(copied_mol)
                result = mol_split_fun(
                    np.array(mols_without_atommaps), sampler="scaffold", **astartes_kwargs
                )
                train, val, test = _unpack_astartes_result(result, include_val)

            # used to constrain data with the same SMILES to go in the same split
            case SplitType.RANDOM_WITH_REPEATED_SMILES:
                # get two arrays: one of all the SMILES strings, one of just the unique ones
                all_smiles = np.array([Chem.MolToSmiles(mol) for mol in mols])
                unique_smiles = np.unique(all_smiles)

                # save a mapping of smiles -> all the indices that it appeared at
                smiles_indices = {}
                for smiles in unique_smiles:
                    smiles_indices[smiles] = np.where(all_smiles == smiles)[0].tolist()

                # randomly split the unique SMILES
                result = split_fun(
                    np.arange(len(unique_smiles)), sampler="random", **astartes_kwargs
                )
                train_idxs, val_idxs, test_idxs = _unpack_astartes_result(result, include_val)

                # convert these to the 'actual' indices from the original list using the dict we made
                train = sum((smiles_indices[unique_smiles[i]] for i in train_idxs), [])
                val = sum((smiles_indices[unique_smiles[j]] for j in val_idxs), [])
                test = sum((smiles_indices[unique_smiles[k]] for k in test_idxs), [])

            case SplitType.RANDOM:
                result = split_fun(np.arange(n_datapoints), sampler="random", **astartes_kwargs)
                train, val, test = _unpack_astartes_result(result, include_val)

            case SplitType.KENNARD_STONE:
                result = mol_split_fun(
                    np.array(mols),
                    sampler="kennard_stone",
                    hopts=dict(metric="jaccard"),
                    fingerprint="morgan_fingerprint",
                    fprints_hopts=dict(n_bits=2048),
                    **astartes_kwargs,
                )
                train, val, test = _unpack_astartes_result(result, include_val)

            case SplitType.KMEANS:
                result = mol_split_fun(
                    np.array(mols),
                    sampler="kmeans",
                    hopts=dict(metric="jaccard"),
                    fingerprint="morgan_fingerprint",
                    fprints_hopts=dict(n_bits=2048),
                    **astartes_kwargs,
                )
                train, val, test = _unpack_astartes_result(result, include_val)

            case _:
                raise RuntimeError("Unreachable code reached!")
        train_replicates.append(train)
        val_replicates.append(val)
        test_replicates.append(test)
        astartes_kwargs["random_state"] += 1
    return train_replicates, val_replicates, test_replicates


def _unpack_astartes_result(
    result: tuple, include_val: bool
) -> tuple[list[int], list[int], list[int]]:
    """Helper function to partition input data based on the output of the astartes sampler

    Parameters
    ----------
    result : tuple
        Output from a call to astartes containing the split indices
    include_val : bool
        True if a validation set is included, False otherwise.

    Returns
    -------
    train : list[int]
    val : list[int]
        .. important::
            validation is possibly empty
    test : list[int]
    """
    train_idxs, val_idxs, test_idxs = [], [], []
    # astartes returns a set of lists containing the data, clusters (if applicable),
    # and indices (always last), so we pull out the indices
    if include_val:
        train_idxs, val_idxs, test_idxs = result[-3], result[-2], result[-1]
    else:
        train_idxs, test_idxs = result[-2], result[-1]
    return list(train_idxs), list(val_idxs), list(test_idxs)


def split_data_by_indices(
    data: Datapoints | MulticomponentDatapoints,
    train_indices: Iterable[Iterable[int]] | None = None,
    val_indices: Iterable[Iterable[int]] | None = None,
    test_indices: Iterable[Iterable[int]] | None = None,
):
    """Splits data into training, validation, and test groups based on the split indices given."""

    train_data = _splitter_helper(data, train_indices)
    val_data = _splitter_helper(data, val_indices)
    test_data = _splitter_helper(data, test_indices)

    return train_data, val_data, test_data


def _splitter_helper(data, indices):
    if indices is None:
        return None

    if isinstance(data[0], (MoleculeDatapoint, ReactionDatapoint)):
        datapoints = data
        idxss = indices
        return [[datapoints[idx] for idx in idxs] for idxs in idxss]
    else:
        datapointss = data
        idxss = indices
        return [[[datapoints[idx] for idx in idxs] for datapoints in datapointss] for idxs in idxss]
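A minimal sketch of generating replicate splits and then applying them to the datapoints; the molecule list is an illustrative placeholder.

from rdkit import Chem

from chemprop.data.datapoints import MoleculeDatapoint
from chemprop.data.splitting import make_split_indices, split_data_by_indices

smis = ["CCO", "CCN", "CCC", "CCCl", "CCBr", "CCF", "CCI", "CCCC", "CC=O", "c1ccccc1"]
mols = [Chem.MolFromSmiles(s) for s in smis]
data = [MoleculeDatapoint(mol) for mol in mols]

# one random 80/10/10 replicate; each return value is a list of index lists
train_idxs, val_idxs, test_idxs = make_split_indices(mols, "random", (0.8, 0.1, 0.1))

train, val, test = split_data_by_indices(data, train_idxs, val_idxs, test_idxs)
print(len(train[0]), len(val[0]), len(test[0]))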
chemprop-updated/chemprop/exceptions.py
ADDED
@@ -0,0 +1,12 @@
from typing import Iterable

from chemprop.utils import pretty_shape


class InvalidShapeError(ValueError):
    def __init__(self, var_name: str, received: Iterable[int], expected: Iterable[int]):
        message = (
            f"arg '{var_name}' has incorrect shape! "
            f"got: `{pretty_shape(received)}`. expected: `{pretty_shape(expected)}`"
        )
        super().__init__(message)
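A tiny sketch of raising this error from validation code; the array shape and expectation are made up, and the exact message formatting comes from `pretty_shape`.

import numpy as np

from chemprop.exceptions import InvalidShapeError

Y = np.zeros((10, 2))
try:
    if Y.shape[1] != 3:
        raise InvalidShapeError("Y", Y.shape, (10, 3))
except InvalidShapeError as e:
    print(e)  # reports both the received and the expected shapes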
chemprop-updated/chemprop/features/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (1.46 kB)
chemprop-updated/chemprop/features/__pycache__/features_generators.cpython-37.pyc
ADDED
Binary file (5.71 kB)
chemprop-updated/chemprop/features/__pycache__/featurization.cpython-37.pyc
ADDED
Binary file (24.7 kB)
chemprop-updated/chemprop/features/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (4.86 kB)
chemprop-updated/chemprop/featurizers/__init__.py
ADDED
@@ -0,0 +1,52 @@
from .atom import AtomFeatureMode, MultiHotAtomFeaturizer, get_multi_hot_atom_featurizer
from .base import Featurizer, GraphFeaturizer, S, T, VectorFeaturizer
from .bond import MultiHotBondFeaturizer
from .molecule import (
    BinaryFeaturizerMixin,
    CountFeaturizerMixin,
    MoleculeFeaturizerRegistry,
    MorganBinaryFeaturizer,
    MorganCountFeaturizer,
    MorganFeaturizerMixin,
    RDKit2DFeaturizer,
    V1RDKit2DFeaturizer,
    V1RDKit2DNormalizedFeaturizer,
)
from .molgraph import (
    CGRFeaturizer,
    CondensedGraphOfReactionFeaturizer,
    MolGraphCache,
    MolGraphCacheFacade,
    MolGraphCacheOnTheFly,
    RxnMode,
    SimpleMoleculeMolGraphFeaturizer,
)

__all__ = [
    "Featurizer",
    "S",
    "T",
    "VectorFeaturizer",
    "GraphFeaturizer",
    "MultiHotAtomFeaturizer",
    "AtomFeatureMode",
    "get_multi_hot_atom_featurizer",
    "MultiHotBondFeaturizer",
    "MolGraphCacheFacade",
    "MolGraphCache",
    "MolGraphCacheOnTheFly",
    "SimpleMoleculeMolGraphFeaturizer",
    "CondensedGraphOfReactionFeaturizer",
    "CGRFeaturizer",
    "RxnMode",
    "MoleculeFeaturizer",
    "MorganFeaturizerMixin",
    "BinaryFeaturizerMixin",
    "CountFeaturizerMixin",
    "MorganBinaryFeaturizer",
    "MorganCountFeaturizer",
    "RDKit2DFeaturizer",
    "MoleculeFeaturizerRegistry",
    "V1RDKit2DFeaturizer",
    "V1RDKit2DNormalizedFeaturizer",
]
chemprop-updated/chemprop/featurizers/atom.py
ADDED
@@ -0,0 +1,281 @@
from enum import auto
from typing import Sequence

import numpy as np
from rdkit.Chem.rdchem import Atom, HybridizationType

from chemprop.featurizers.base import VectorFeaturizer
from chemprop.utils.utils import EnumMapping


class MultiHotAtomFeaturizer(VectorFeaturizer[Atom]):
    """A :class:`MultiHotAtomFeaturizer` uses a multi-hot encoding to featurize atoms.

    .. seealso::
        The class provides three default parameterization schemes:

        * :meth:`MultiHotAtomFeaturizer.v1`
        * :meth:`MultiHotAtomFeaturizer.v2`
        * :meth:`MultiHotAtomFeaturizer.organic`

    The generated atom features are ordered as follows:
    * atomic number
    * degree
    * formal charge
    * chiral tag
    * number of hydrogens
    * hybridization
    * aromaticity
    * mass

    .. important::
        Each feature, except for aromaticity and mass, includes a pad for unknown values.

    Parameters
    ----------
    atomic_nums : Sequence[int]
        the choices for atom type denoted by atomic number. Ex: ``[6, 7, 8]`` for C, N, and O.
    degrees : Sequence[int]
        the choices for number of bonds an atom is engaged in.
    formal_charges : Sequence[int]
        the choices for integer electronic charge assigned to an atom.
    chiral_tags : Sequence[int]
        the choices for an atom's chiral tag. See :class:`rdkit.Chem.rdchem.ChiralType` for possible integer values.
    num_Hs : Sequence[int]
        the choices for number of bonded hydrogen atoms.
    hybridizations : Sequence[int]
        the choices for an atom’s hybridization type. See :class:`rdkit.Chem.rdchem.HybridizationType` for possible integer values.
    """

    def __init__(
        self,
        atomic_nums: Sequence[int],
        degrees: Sequence[int],
        formal_charges: Sequence[int],
        chiral_tags: Sequence[int],
        num_Hs: Sequence[int],
        hybridizations: Sequence[int],
    ):
        self.atomic_nums = {j: i for i, j in enumerate(atomic_nums)}
        self.degrees = {i: i for i in degrees}
        self.formal_charges = {j: i for i, j in enumerate(formal_charges)}
        self.chiral_tags = {i: i for i in chiral_tags}
        self.num_Hs = {i: i for i in num_Hs}
        self.hybridizations = {ht: i for i, ht in enumerate(hybridizations)}

        self._subfeats: list[dict] = [
            self.atomic_nums,
            self.degrees,
            self.formal_charges,
            self.chiral_tags,
            self.num_Hs,
            self.hybridizations,
        ]
        subfeat_sizes = [
            1 + len(self.atomic_nums),
            1 + len(self.degrees),
            1 + len(self.formal_charges),
            1 + len(self.chiral_tags),
            1 + len(self.num_Hs),
            1 + len(self.hybridizations),
            1,
            1,
        ]
        self.__size = sum(subfeat_sizes)

    def __len__(self) -> int:
        return self.__size

    def __call__(self, a: Atom | None) -> np.ndarray:
        x = np.zeros(self.__size)

        if a is None:
            return x

        feats = [
            a.GetAtomicNum(),
            a.GetTotalDegree(),
            a.GetFormalCharge(),
            int(a.GetChiralTag()),
            int(a.GetTotalNumHs()),
            a.GetHybridization(),
        ]
        i = 0
        for feat, choices in zip(feats, self._subfeats):
            j = choices.get(feat, len(choices))
            x[i + j] = 1
            i += len(choices) + 1
        x[i] = int(a.GetIsAromatic())
        x[i + 1] = 0.01 * a.GetMass()

        return x

    def num_only(self, a: Atom) -> np.ndarray:
        """featurize the atom by setting only the atomic number bit"""
        x = np.zeros(len(self))

        if a is None:
            return x

        i = self.atomic_nums.get(a.GetAtomicNum(), len(self.atomic_nums))
        x[i] = 1

        return x

    @classmethod
    def v1(cls, max_atomic_num: int = 100):
        r"""The original implementation used in Chemprop V1 [1]_, [2]_.

        Parameters
        ----------
        max_atomic_num : int, default=100
            Include a bit for all atomic numbers in the interval :math:`[1, \mathtt{max\_atomic\_num}]`

        References
        ----------
        .. [1] Yang, K.; Swanson, K.; Jin, W.; Coley, C.; Eiden, P.; Gao, H.; Guzman-Perez, A.; Hopper, T.;
            Kelley, B.; Mathea, M.; Palmer, A. "Analyzing Learned Molecular Representations for Property Prediction."
            J. Chem. Inf. Model. 2019, 59 (8), 3370–3388. https://doi.org/10.1021/acs.jcim.9b00237
        .. [2] Heid, E.; Greenman, K.P.; Chung, Y.; Li, S.C.; Graff, D.E.; Vermeire, F.H.; Wu, H.; Green, W.H.; McGill,
            C.J. "Chemprop: A machine learning package for chemical property prediction." J. Chem. Inf. Model. 2024,
            64 (1), 9–17. https://doi.org/10.1021/acs.jcim.3c01250
        """

        return cls(
            atomic_nums=list(range(1, max_atomic_num + 1)),
            degrees=list(range(6)),
            formal_charges=[-1, -2, 1, 2, 0],
            chiral_tags=list(range(4)),
            num_Hs=list(range(5)),
            hybridizations=[
                HybridizationType.SP,
                HybridizationType.SP2,
                HybridizationType.SP3,
                HybridizationType.SP3D,
                HybridizationType.SP3D2,
            ],
        )

    @classmethod
    def v2(cls):
        """An implementation that includes an atom type bit for all elements in the first four rows of the periodic table plus iodine."""

        return cls(
            atomic_nums=list(range(1, 37)) + [53],
            degrees=list(range(6)),
            formal_charges=[-1, -2, 1, 2, 0],
            chiral_tags=list(range(4)),
            num_Hs=list(range(5)),
            hybridizations=[
                HybridizationType.S,
                HybridizationType.SP,
                HybridizationType.SP2,
                HybridizationType.SP2D,
                HybridizationType.SP3,
                HybridizationType.SP3D,
                HybridizationType.SP3D2,
            ],
        )

    @classmethod
    def organic(cls):
        r"""A specific parameterization intended for use with organic or drug-like molecules.

        This parameterization features:
        1. an atomic number bit only for H, B, C, N, O, F, Si, P, S, Cl, Br, and I atoms
        2. a hybridization bit for :math:`s, sp, sp^2` and :math:`sp^3` hybridizations.
        """

        return cls(
            atomic_nums=[1, 5, 6, 7, 8, 9, 14, 15, 16, 17, 35, 53],
            degrees=list(range(6)),
            formal_charges=[-1, -2, 1, 2, 0],
            chiral_tags=list(range(4)),
            num_Hs=list(range(5)),
            hybridizations=[
                HybridizationType.S,
                HybridizationType.SP,
                HybridizationType.SP2,
                HybridizationType.SP3,
            ],
        )


class RIGRAtomFeaturizer(VectorFeaturizer[Atom]):
    """A :class:`RIGRAtomFeaturizer` uses a multi-hot encoding to featurize atoms using resonance-invariant features.

    The generated atom features are ordered as follows:
    * atomic number
    * degree
    * number of hydrogens
    * mass
    """

    def __init__(
        self,
        atomic_nums: Sequence[int] | None = None,
        degrees: Sequence[int] | None = None,
        num_Hs: Sequence[int] | None = None,
    ):
        self.atomic_nums = {j: i for i, j in enumerate(atomic_nums or list(range(1, 37)) + [53])}
        self.degrees = {i: i for i in (degrees or list(range(6)))}
        self.num_Hs = {i: i for i in (num_Hs or list(range(5)))}

        self._subfeats: list[dict] = [self.atomic_nums, self.degrees, self.num_Hs]
        subfeat_sizes = [1 + len(self.atomic_nums), 1 + len(self.degrees), 1 + len(self.num_Hs), 1]
        self.__size = sum(subfeat_sizes)

    def __len__(self) -> int:
        return self.__size

    def __call__(self, a: Atom | None) -> np.ndarray:
        x = np.zeros(self.__size)

        if a is None:
            return x

        feats = [a.GetAtomicNum(), a.GetTotalDegree(), int(a.GetTotalNumHs())]
        i = 0
        for feat, choices in zip(feats, self._subfeats):
            j = choices.get(feat, len(choices))
            x[i + j] = 1
            i += len(choices) + 1
        x[i] = 0.01 * a.GetMass()  # scaled to about the same range as the other features

        return x

    def num_only(self, a: Atom) -> np.ndarray:
        """featurize the atom by setting only the atomic number bit"""
        x = np.zeros(len(self))

        if a is None:
            return x

        i = self.atomic_nums.get(a.GetAtomicNum(), len(self.atomic_nums))
        x[i] = 1

        return x


class AtomFeatureMode(EnumMapping):
    """The mode by which an atom is featurized into a `MolGraph`"""

    V1 = auto()
    V2 = auto()
    ORGANIC = auto()
    RIGR = auto()


def get_multi_hot_atom_featurizer(mode: str | AtomFeatureMode) -> MultiHotAtomFeaturizer:
    """Build the corresponding multi-hot atom featurizer."""
    match AtomFeatureMode.get(mode):
        case AtomFeatureMode.V1:
            return MultiHotAtomFeaturizer.v1()
        case AtomFeatureMode.V2:
            return MultiHotAtomFeaturizer.v2()
        case AtomFeatureMode.ORGANIC:
            return MultiHotAtomFeaturizer.organic()
        case AtomFeatureMode.RIGR:
            return RIGRAtomFeaturizer()
        case _:
            raise RuntimeError("unreachable code reached!")
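A short sketch of featurizing a single atom with one of the default schemes; the molecule is a placeholder, and passing the mode as the string "V2" assumes `EnumMapping.get` resolves member names from strings, as its use above suggests.

from rdkit import Chem

from chemprop.featurizers import get_multi_hot_atom_featurizer

featurizer = get_multi_hot_atom_featurizer("V2")
atom = Chem.MolFromSmiles("CCO").GetAtomWithIdx(0)  # a carbon atom

x = featurizer(atom)
print(len(featurizer), x.shape)  # feature length and the multi-hot vector shape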
chemprop-updated/chemprop/featurizers/base.py
ADDED
@@ -0,0 +1,30 @@
from abc import abstractmethod
from collections.abc import Sized
from typing import Generic, TypeVar

import numpy as np

from chemprop.data.molgraph import MolGraph

S = TypeVar("S")
T = TypeVar("T")


class Featurizer(Generic[S, T]):
    """A :class:`Featurizer` featurizes inputs of type ``S`` into outputs of
    type ``T``."""

    @abstractmethod
    def __call__(self, input: S, *args, **kwargs) -> T:
        """featurize an input"""


class VectorFeaturizer(Featurizer[S, np.ndarray], Sized):
    ...


class GraphFeaturizer(Featurizer[S, MolGraph]):
    @property
    @abstractmethod
    def shape(self) -> tuple[int, int]:
        ...
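A minimal sketch of a custom featurizer built on these base classes; the two-feature encoding is purely illustrative.

import numpy as np
from rdkit.Chem.rdchem import Atom

from chemprop.featurizers.base import VectorFeaturizer


class TinyAtomFeaturizer(VectorFeaturizer[Atom]):
    """Encodes only the atomic number and aromaticity of an atom."""

    def __len__(self) -> int:
        return 2

    def __call__(self, a: Atom | None) -> np.ndarray:
        x = np.zeros(len(self))  # zero vector doubles as the "no atom" encoding
        if a is not None:
            x[0] = a.GetAtomicNum()
            x[1] = float(a.GetIsAromatic())
        return x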