Commit: Upload 625 files (add model and bbbp model weight)

This view is limited to 50 files because it contains too many changes; see the raw diff for the complete change set.
- .gitattributes +2 -0
- chemprop-updated/.bumpversion.cfg +10 -0
- chemprop-updated/.dockerignore +3 -0
- chemprop-updated/.flake8 +9 -0
- chemprop-updated/.github/ISSUE_TEMPLATE/todo.md +11 -0
- chemprop-updated/.github/ISSUE_TEMPLATE/v1_bug_report.md +35 -0
- chemprop-updated/.github/ISSUE_TEMPLATE/v1_question.md +17 -0
- chemprop-updated/.github/ISSUE_TEMPLATE/v2_bug_report.md +35 -0
- chemprop-updated/.github/ISSUE_TEMPLATE/v2_feature_request.md +23 -0
- chemprop-updated/.github/ISSUE_TEMPLATE/v2_question.md +17 -0
- chemprop-updated/.github/PULL_REQUEST_TEMPLATE.md +18 -0
- chemprop-updated/.github/PULL_REQUEST_TEMPLATE/bugfix.md +12 -0
- chemprop-updated/.github/PULL_REQUEST_TEMPLATE/new_feature.md +15 -0
- chemprop-updated/.github/workflows/ci.yml +158 -0
- chemprop-updated/.gitignore +178 -0
- chemprop-updated/.readthedocs.yml +19 -0
- chemprop-updated/CITATIONS.bib +37 -0
- chemprop-updated/CONTRIBUTING.md +40 -0
- chemprop-updated/Dockerfile +50 -0
- chemprop-updated/LICENSE.txt +27 -0
- chemprop-updated/README.md +63 -0
- chemprop-updated/chemprop/__init__.py +5 -0
- chemprop-updated/chemprop/__pycache__/__init__.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/args.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/constants.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/hyperopt_utils.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/hyperparameter_optimization.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/interpret.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/multitask_utils.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/nn_utils.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/rdkit.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/sklearn_predict.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/sklearn_train.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/spectra_utils.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/__pycache__/utils.cpython-37.pyc +0 -0
- chemprop-updated/chemprop/cli/common.py +211 -0
- chemprop-updated/chemprop/cli/conf.py +9 -0
- chemprop-updated/chemprop/cli/convert.py +55 -0
- chemprop-updated/chemprop/cli/fingerprint.py +182 -0
- chemprop-updated/chemprop/cli/hpopt.py +537 -0
- chemprop-updated/chemprop/cli/main.py +85 -0
- chemprop-updated/chemprop/cli/predict.py +444 -0
- chemprop-updated/chemprop/cli/train.py +1340 -0
- chemprop-updated/chemprop/cli/utils/__init__.py +30 -0
- chemprop-updated/chemprop/cli/utils/actions.py +19 -0
- chemprop-updated/chemprop/cli/utils/args.py +34 -0
- chemprop-updated/chemprop/cli/utils/command.py +24 -0
- chemprop-updated/chemprop/cli/utils/parsing.py +446 -0
- chemprop-updated/chemprop/cli/utils/utils.py +31 -0
- chemprop-updated/chemprop/conf.py +6 -0
.gitattributes
CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+chemprop-updated/docs/source/_static/images/message_passing.png filter=lfs diff=lfs merge=lfs -text
+chemprop/docs/source/_static/images/message_passing.png filter=lfs diff=lfs merge=lfs -text
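Git LFS picks these patterns up from .gitattributes and routes matching paths through the LFS filter. As a rough illustration only (not how git implements it; the sample file name "weights.zip" is hypothetical), a minimal Python sketch of the matching using the standard library:

from fnmatch import fnmatch

# Patterns copied from the .gitattributes diff above; fnmatch is only an
# approximation of gitattributes glob semantics.
lfs_patterns = [
    "*.zip",
    "*.zst",
    "*tfevents*",
    "chemprop-updated/docs/source/_static/images/message_passing.png",
]

for path in ["weights.zip", "chemprop-updated/docs/source/_static/images/message_passing.png"]:
    print(path, any(fnmatch(path, pattern) for pattern in lfs_patterns))  # both print True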
chemprop-updated/.bumpversion.cfg
ADDED
@@ -0,0 +1,10 @@
+[bumpversion]
+current_version = 2.1.2
+commit = True
+tag = True
+
+[bumpversion:file:pyproject.toml]
+
+[bumpversion:file:chemprop/__init__.py]
+
+[bumpversion:file:docs/source/conf.py]
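This config drives bump2version: a release bump rewrites the version string in each listed file and, per commit = True and tag = True, commits and tags the result. A minimal standard-library sketch that lists those targets (the relative path assumes the layout of this upload):

from configparser import ConfigParser

# Read the config shown above and collect the files whose version strings are
# kept in sync on each bump ([bumpversion:file:<path>] sections).
config = ConfigParser()
config.read("chemprop-updated/.bumpversion.cfg")

print(config["bumpversion"]["current_version"])  # 2.1.2
targets = [s.split(":", 2)[2] for s in config.sections() if s.startswith("bumpversion:file:")]
print(targets)  # ['pyproject.toml', 'chemprop/__init__.py', 'docs/source/conf.py']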
chemprop-updated/.dockerignore
ADDED
@@ -0,0 +1,3 @@
+**.git*
+.dockerignore
+Dockerfile
chemprop-updated/.flake8
ADDED
@@ -0,0 +1,9 @@
+[flake8]
+ignore = E203, E266, E501, F403, E741, W503, W605
+max-line-length = 100
+max-complexity = 18
+per-file-ignores =
+    __init__.py: F401
+    chemprop/nn/predictors.py: F405
+    chemprop/nn/metrics.py: F405
+    tests/unit/nn/test_metrics.py: E121, E122, E131, E241, W291
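The CI workflow later in this diff runs `flake8 .` against this configuration. The same check can also be run from Python through flake8's legacy API; a minimal sketch, assuming flake8 is installed and the working directory is the repository root so the [flake8] section above is discovered:

# Run flake8 programmatically; configuration discovery works as it does for
# the `flake8 .` CLI call.
from flake8.api.legacy import get_style_guide

style_guide = get_style_guide()
report = style_guide.check_files(["chemprop"])
print(report.total_errors)  # 0 when the tree is clean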
chemprop-updated/.github/ISSUE_TEMPLATE/todo.md
ADDED
@@ -0,0 +1,11 @@
+---
+name: to-do
+about: Add an item to the to-do list. More generic than a feature request
+title: "[TODO]: "
+labels: todo
+assignees: ''
+
+---
+
+**Notes**
+_these could be implementation or more specific details to keep in mind, if they'll be helpful for issue tracking_
chemprop-updated/.github/ISSUE_TEMPLATE/v1_bug_report.md
ADDED
@@ -0,0 +1,35 @@
+---
+name: v1 Bug Report
+about: Report a bug in v1 (will not be fixed)
+title: "[v1 BUG]: "
+labels: bug, v1-wontfix
+assignees: ''
+
+---
+
+**Describe the bug**
+A clear and concise description of what the bug is.
+
+**Example(s)**
+Provide some examples of where the current code fails. Feel free to share your actual code for additional context, but a minimal and isolated example is preferred.
+
+**Expected behavior**
+A clear and concise description of what you expected to happen. If there is correct, expected output, include that here as well.
+
+**Error Stack Trace**
+If the bug is resulting in an error message, provide the _full_ stack trace (not just the last line). This is helpful for debugging, especially in cases where you aren't able to provide a minimum/isolated working example with accompanying files.
+
+**Screenshots**
+If applicable, add screenshots to help explain your problem.
+
+**Environment**
+- python version
+- package versions: `conda list` or `pip list`
+- OS
+
+**Checklist**
+- [ ] all dependencies are satisfied: `conda list` or `pip list` shows the packages listed in the `pyproject.toml`
+- [ ] the unit tests are working: `pytest -v` reports no errors
+
+**Additional context**
+Add any other context about the problem here.
chemprop-updated/.github/ISSUE_TEMPLATE/v1_question.md
ADDED
@@ -0,0 +1,17 @@
+---
+name: v1 Question
+about: Have a question about how to use Chemprop v1?
+title: "[v1 QUESTION]: "
+labels: question
+assignees: ''
+
+---
+
+**What are you trying to do?**
+Please tell us what you're trying to do with Chemprop, providing as much detail as possible
+
+**Previous attempts**
+If possible, provide some examples of what you've already tried and what the output was.
+
+**Screenshots**
+If applicable, add screenshots to help explain your problem.
chemprop-updated/.github/ISSUE_TEMPLATE/v2_bug_report.md
ADDED
@@ -0,0 +1,35 @@
+---
+name: v2 Bug Report
+about: Create a report to help us improve
+title: "[v2 BUG]: "
+labels: bug
+assignees: ''
+
+---
+
+**Describe the bug**
+A clear and concise description of what the bug is.
+
+**Example(s)**
+Provide some examples of where the current code fails. Feel free to share your actual code for additional context, but a minimal and isolated example is preferred.
+
+**Expected behavior**
+A clear and concise description of what you expected to happen. If there is correct, expected output, include that here as well.
+
+**Error Stack Trace**
+If the bug is resulting in an error message, provide the _full_ stack trace (not just the last line). This is helpful for debugging, especially in cases where you aren't able to provide a minimum/isolated working example with accompanying files.
+
+**Screenshots**
+If applicable, add screenshots to help explain your problem.
+
+**Environment**
+- python version
+- package versions: `conda list` or `pip list`
+- OS
+
+**Checklist**
+- [ ] all dependencies are satisfied: `conda list` or `pip list` shows the packages listed in the `pyproject.toml`
+- [ ] the unit tests are working: `pytest -v` reports no errors
+
+**Additional context**
+Add any other context about the problem here.
chemprop-updated/.github/ISSUE_TEMPLATE/v2_feature_request.md
ADDED
@@ -0,0 +1,23 @@
+---
+name: v2 Feature Request
+about: Suggest an idea for this project
+title: "[v2 FEATURE]: "
+labels: enhancement
+assignees: ''
+
+---
+
+**Is your feature request related to a problem? Please describe.**
+A clear and concise description of what the problem is.
+
+**Use-cases/examples of this new feature**
+What are some example workflows that would employ this new feature? Are there any relevant issues?
+
+**Desired solution/workflow**
+A clear and concise description of what you want to happen. Include some (pseudo)code, if possible
+
+**Discussion**
+What are some considerations around this new feature? Are there alternative approaches to consider? What should the scope of the feature be?
+
+**Additional context**
+Add any other context or screenshots about the feature request here.
chemprop-updated/.github/ISSUE_TEMPLATE/v2_question.md
ADDED
@@ -0,0 +1,17 @@
+---
+name: v2 Question
+about: Have a question about how to use Chemprop v2?
+title: "[v2 QUESTION]: "
+labels: question
+assignees: ''
+
+---
+
+**What are you trying to do?**
+Please tell us what you're trying to do with Chemprop, providing as much detail as possible
+
+**Previous attempts**
+If possible, provide some examples of what you've already tried and what the output was.
+
+**Screenshots**
+If applicable, add screenshots to help explain your problem.
chemprop-updated/.github/PULL_REQUEST_TEMPLATE.md
ADDED
@@ -0,0 +1,18 @@
+## Description
+Include a brief summary of the bug/feature/etc. that this PR seeks to address
+
+## Example / Current workflow
+Include a sample workflow to either **(a)** reproduce the bug with the current codebase or **(b)** showcase the deficiency that this PR seeks to address
+
+## Bugfix / Desired workflow
+Include either **(a)** the same workflow from above with the correct output produced via this PR or **(b)** some (pseudo)code containing the new workflow that this PR will (seek to) implement
+
+## Questions
+If there are open questions about implementation strategy or scope of the PR, include them here
+
+## Relevant issues
+If appropriate, please tag them here and include a quick summary
+
+## Checklist
+- [ ] linted with flake8?
+- [ ] (if appropriate) unit tests added?
chemprop-updated/.github/PULL_REQUEST_TEMPLATE/bugfix.md
ADDED
@@ -0,0 +1,12 @@
+## Bug report
+Include a brief summary of the bug that this PR seeks to address. If possible, include relevant issue tags
+
+## Example
+Include a sample execution to reproduce the bug with the current codebase, and some sample output showcasing that the PR fixes this bug
+
+## Questions
+If there are open questions about implementation strategy or scope of the PR, include them here
+
+## Checklist
+- [ ] linted with flake8?
+- [ ] (if necessary) appropriate unit tests added?
chemprop-updated/.github/PULL_REQUEST_TEMPLATE/new_feature.md
ADDED
@@ -0,0 +1,15 @@
+## Statement of need
+What deficiency does this PR seek to address? If there are relevant issues, please tag them here
+
+## Current workflow
+How is this need achieved with the current codebase?
+
+## Desired workflow
+Include some (pseudo)code containing the new workflow that this PR will (seek to) implement
+
+## Questions
+If there are open questions about implementation strategy or scope of the PR, include them here
+
+## Checklist
+- [ ] linted with flake8?
+- [ ] appropriate unit tests added?
chemprop-updated/.github/workflows/ci.yml
ADDED
@@ -0,0 +1,158 @@
+# ci.yml
+#
+# Continuous Integration for Chemprop - checks build, code formatting, and runs tests for all
+# proposed changes and on a regular schedule
+#
+# Note: this file contains extensive inline documentation to aid with knowledge transfer.
+
+name: Continuous Integration
+
+on:
+  # run on pushes/pull requests to/against main
+  push:
+    branches: [main]
+  pull_request:
+    branches: [main]
+  # run this in the morning on weekdays to catch dependency issues
+  schedule:
+    - cron: "0 8 * * 1-5"
+  # allow manual runs
+  workflow_dispatch:
+
+# cancel previously running tests if new commits are made
+# https://docs.github.com/en/actions/examples/using-concurrency-expressions-and-a-test-matrix
+concurrency:
+  group: actions-id-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+  cancel-in-progress: true
+
+env:
+  USE_LIBUV: 0  # libuv doesn't work on GitHub actions Windows runner
+
+jobs:
+  build:
+    name: Check Build
+    runs-on: ubuntu-latest
+    steps:
+      # clone the repo, attempt to build
+      - uses: actions/checkout@v4
+      - run: python -m pip install build
+      - run: python -m build .
+
+  lint:
+    name: Check Formatting
+    needs: build
+    runs-on: ubuntu-latest
+    steps:
+      # clone the repo, run black and flake8 on it
+      - uses: actions/checkout@v4
+      - run: python -m pip install black==23.* flake8 isort
+      - run: black --check .
+      - run: flake8 .
+      - run: isort --check .
+
+  test:
+    name: Execute Tests
+    needs: lint
+    runs-on: ${{ matrix.os }}
+    defaults:
+      run:
+        # run with a login shell (so that the conda environment is activated)
+        # and echo the commands we run as we do them (for debugging purposes)
+        shell: bash -el {0}
+    strategy:
+      # if one platform/python version fails, continue testing the others
+      fail-fast: false
+      matrix:
+        # test on all platforms with both supported versions of Python
+        os: [ubuntu-latest, macos-13, windows-latest]
+        python-version: [3.11, 3.12]
+    steps:
+      - uses: actions/checkout@v4
+      # use a version of the conda virtual environment manager to set up an
+      # isolated environment with the Python version we want
+      - uses: conda-incubator/setup-miniconda@v3
+        with:
+          python-version: ${{ matrix.python-version }}
+          auto-update-conda: true
+          show-channel-urls: true
+          conda-remove-defaults: "true"
+          environment-file: environment.yml
+          activate-environment: chemprop
+      - name: Install dependencies
+        shell: bash -l {0}
+        run: |
+          python -m pip install nbmake
+          python -m pip install ".[dev,docs,test,hpopt]"
+      - name: Test with pytest
+        shell: bash -l {0}
+        run: |
+          pytest -v tests
+      - name: Test notebooks
+        shell: bash -l {0}
+        run: |
+          python -m pip install matplotlib
+          pytest --no-cov -v --nbmake $(find examples -name '*.ipynb' ! -name 'use_featurizer_with_other_libraries.ipynb' ! -name 'shapley_value_with_customized_featurizers.ipynb')
+          pytest --no-cov -v --nbmake $(find docs/source/tutorial/python -name "*.ipynb")
+
+  pypi:
+    name: Build and publish Python 🐍 distributions 📦 to PyPI
+    runs-on: ubuntu-latest
+    # only run if the tests pass
+    needs: [test]
+    # run only on pushes to main on chemprop
+    if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' && github.repository == 'chemprop/chemprop' }}
+    steps:
+      - uses: actions/checkout@master
+      - name: Set up Python 3.11
+        uses: actions/setup-python@v3
+        with:
+          python-version: "3.11"
+      - name: Install pypa/build
+        run: >-
+          python -m
+          pip install
+          build
+          --user
+      - name: Build a binary wheel and a source tarball
+        run: >-
+          python -m
+          build
+          --sdist
+          --wheel
+          --outdir dist/
+          .
+      - name: Publish distribution 📦 to PyPI
+        uses: pypa/gh-action-pypi-publish@release/v1
+        with:
+          password: ${{ secrets.PYPI_API_TOKEN }}
+          skip-existing: true
+          verbose: true
+
+  build-and-push-docker:
+    # shamelessly copied from:
+    # https://github.com/ReactionMechanismGenerator/RMG-Py/blob/bfaee1cad9909a17103a8e6ef9a22569c475964c/.github/workflows/CI.yml#L359C1-L386C54
+    # which is also shamelessly copied from somewhere
+    runs-on: ubuntu-latest
+    # only run if the tests pass
+    needs: [test]
+    # run only on pushes to main on chemprop
+    if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' && github.repository == 'chemprop/chemprop' }}
+    steps:
+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v2
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v2
+
+      - name: Login to Docker Hub
+        uses: docker/login-action@v2
+        with:
+          # repository secrets managed by the maintainers
+          username: ${{ secrets.DOCKERHUB_USERNAME }}
+          password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+      - name: Build and Push
+        uses: docker/build-push-action@v4
+        with:
+          push: true
+          tags: chemprop/chemprop:latest
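The notebook test step shells out to `find` to collect every example notebook except two excluded by name. A minimal Python sketch of the equivalent collection logic (assuming it runs from the repository root), whose result would then be handed to `pytest --no-cov -v --nbmake`:

from pathlib import Path

# The two exclusions mirror the `! -name ...` clauses in the CI command above.
EXCLUDED = {
    "use_featurizer_with_other_libraries.ipynb",
    "shapley_value_with_customized_featurizers.ipynb",
}

example_nbs = [p for p in Path("examples").rglob("*.ipynb") if p.name not in EXCLUDED]
tutorial_nbs = sorted(Path("docs/source/tutorial/python").rglob("*.ipynb"))
print(len(example_nbs), len(tutorial_nbs))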
chemprop-updated/.gitignore
ADDED
@@ -0,0 +1,178 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+*.idea
+*.DS_Store
+*.vscode
+*.csv
+*.pkl
+*.pt
+*.json
+*.sqlite3
+*.yaml
+*.tfevents.*
+*.ckpt
+chemprop/_version.py
+*.ckpt
+*.ipynb
+config.toml
+
+!tests/data/*
chemprop-updated/.readthedocs.yml
ADDED
@@ -0,0 +1,19 @@
+# .readthedocs.yml
+# Read the Docs configuration file
+# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+# Required
+version: 2
+
+# Set the OS, Python version and other tools you might need
+build:
+  os: ubuntu-22.04
+  tools:
+    python: "3.11"
+  jobs:
+    post_install:
+      - python -m pip install --upgrade --upgrade-strategy only-if-needed --no-cache-dir ".[docs]"
+
+# Build documentation in the docs/ directory with Sphinx
+sphinx:
+  configuration: docs/source/conf.py
chemprop-updated/CITATIONS.bib
ADDED
@@ -0,0 +1,37 @@
+# this was downloaded from ACS: https://pubs.acs.org/doi/10.1021/acs.jcim.9b00237
+@article{chemprop_theory,
+    author = {Yang, Kevin and Swanson, Kyle and Jin, Wengong and Coley, Connor and Eiden, Philipp and Gao, Hua and Guzman-Perez, Angel and Hopper, Timothy and Kelley, Brian and Mathea, Miriam and Palmer, Andrew and Settels, Volker and Jaakkola, Tommi and Jensen, Klavs and Barzilay, Regina},
+    title = {Analyzing Learned Molecular Representations for Property Prediction},
+    journal = {Journal of Chemical Information and Modeling},
+    volume = {59},
+    number = {8},
+    pages = {3370-3388},
+    year = {2019},
+    doi = {10.1021/acs.jcim.9b00237},
+    note = {PMID: 31361484},
+    URL = {https://doi.org/10.1021/acs.jcim.9b00237},
+    eprint = {https://doi.org/10.1021/acs.jcim.9b00237}
+}
+
+# this was downloaded from ACS: https://pubs.acs.org/doi/10.1021/acs.jcim.3c01250
+@article{chemprop_software,
+    author = {Heid, Esther and Greenman, Kevin P. and Chung, Yunsie and Li, Shih-Cheng and Graff, David E. and Vermeire, Florence H. and Wu, Haoyang and Green, William H. and McGill, Charles J.},
+    title = {Chemprop: A Machine Learning Package for Chemical Property Prediction},
+    journal = {Journal of Chemical Information and Modeling},
+    volume = {64},
+    number = {1},
+    pages = {9-17},
+    year = {2024},
+    doi = {10.1021/acs.jcim.3c01250},
+    note = {PMID: 38147829},
+    URL = {https://doi.org/10.1021/acs.jcim.3c01250},
+    eprint = {https://doi.org/10.1021/acs.jcim.3c01250}
+}
chemprop-updated/CONTRIBUTING.md
ADDED
@@ -0,0 +1,40 @@
+# How to contribute
+
+We welcome contributions from external contributors, and this document
+describes how to merge code changes into this repository.
+
+## Getting Started
+
+* Make sure you have a [GitHub account](https://github.com/signup/free).
+* [Fork](https://help.github.com/articles/fork-a-repo/) this repository on GitHub.
+* On your local machine,
+  [clone](https://help.github.com/articles/cloning-a-repository/) your fork of
+  the repository.
+
+## Making Changes
+
+* Add some really awesome code to your local fork. It's usually a [good
+  idea](http://blog.jasonmeridth.com/posts/do-not-issue-pull-requests-from-your-master-branch/)
+  to make changes on a
+  [branch](https://help.github.com/articles/creating-and-deleting-branches-within-your-repository/)
+  with the branch name relating to the feature you are going to add.
+* When you are ready for others to examine and comment on your new feature,
+  navigate to your fork of `chemprop` on GitHub and open a [pull
+  request](https://help.github.com/articles/using-pull-requests/) (PR). Note that
+  after you launch a PR from one of your fork's branches, all
+  subsequent commits to that branch will be added to the open pull request
+  automatically. Each commit added to the PR will be validated for
+  mergeability, compilation, and test suite compliance; the results of these tests
+  will be visible on the PR page.
+* If you're providing a new feature, you **must** add test cases and documentation.
+* When the code is ready to go, run the test suite: `pytest`.
+* When you're ready to be considered for merging, click the "Ready for review"
+  box on the PR page to let the Chemprop devs know that the changes are complete.
+  The code will not be merged until the continuous integration returns checkmarks,
+  and at least one core developer gives "Approved" reviews.
+
+## Additional Resources
+
+* [General GitHub documentation](https://help.github.com/)
+* [PR best practices](http://codeinthehole.com/writing/pull-requests-and-other-good-practices-for-teams-using-github/)
+* [A guide to contributing to software packages](http://www.contribution-guide.org)
chemprop-updated/Dockerfile
ADDED
@@ -0,0 +1,50 @@
+# Dockerfile
+#
+# Builds a Docker image containing Chemprop and its required dependencies.
+#
+# Build this image with:
+#   git clone https://github.com/chemprop/chemprop.git
+#   cd chemprop
+#   docker build --tag=chemprop:latest .
+#
+# Run the built image with:
+#   docker run --name chemprop_container -it chemprop:latest
+#
+# Note:
+# This image only runs on CPU - we do not provide a Dockerfile
+# for GPU use (see installation documentation).
+
+# Parent Image
+FROM continuumio/miniconda3:latest
+
+# Install libxrender1 (required by RDKit) and then clean up
+RUN apt-get update && \
+    apt-get install -y \
+    libxrender1 && \
+    apt-get autoremove -y && \
+    apt-get clean -y
+
+WORKDIR /opt/chemprop
+
+# build an empty conda environment with appropriate Python version
+RUN conda create --name chemprop_env python=3.11*
+
+# This runs all subsequent commands inside the chemprop_env conda environment
+#
+# Analogous to just activating the environment, which we can't actually do here
+# since that requires running conda init and restarting the shell (not possible
+# in a Dockerfile build script)
+SHELL ["conda", "run", "--no-capture-output", "-n", "chemprop_env", "/bin/bash", "-c"]
+
+# Follow the installation instructions then clear the cache
+ADD chemprop chemprop
+ENV PYTHONPATH /opt/chemprop
+ADD LICENSE.txt pyproject.toml README.md ./
+RUN conda install pytorch cpuonly -c pytorch && \
+    conda clean --all --yes && \
+    python -m pip install . && \
+    python -m pip cache purge
+
+# when running this image, open an interactive bash terminal inside the conda environment
+RUN echo "conda activate chemprop_env" > ~/.bashrc
+ENTRYPOINT ["/bin/bash", "--login"]
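For completeness, the documented build and run commands can also be driven from Python; a minimal sketch, assuming Docker is installed, the current directory contains this Dockerfile, and the script is launched from an interactive terminal (the container name comes from the comment block above):

import subprocess

# Build the CPU-only image and start an interactive container, mirroring the
# commands documented at the top of the Dockerfile. `-it` requires a TTY.
subprocess.run(["docker", "build", "--tag=chemprop:latest", "."], check=True)
subprocess.run(["docker", "run", "--name", "chemprop_container", "-it", "chemprop:latest"], check=True)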
chemprop-updated/LICENSE.txt
ADDED
@@ -0,0 +1,27 @@
+MIT License
+
+Copyright (c) 2024 The Chemprop Development Team (Regina Barzilay,
+Jackson Burns, Yunsie Chung, Anna Doner, Xiaorui Dong, David Graff,
+William Green, Kevin Greenman, Yanfei Guan, Esther Heid, Lior Hirschfeld,
+Tommi Jaakkola, Wengong Jin, Olivier Lafontant-Joseph, Shih-Cheng Li,
+Mengjie Liu, Joel Manu, Charles McGill, Angiras Menon, Nathan Morgan,
+Hao-Wei Pang, Kevin Spiekermann, Kyle Swanson, Allison Tam,
+Florence Vermeire, Haoyang Wu, Kevin Yang, and Jonathan Zheng)
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
chemprop-updated/README.md
ADDED
@@ -0,0 +1,63 @@
+<picture>
+  <source media="(prefers-color-scheme: dark)" srcset="docs/source/_static/images/logo/chemprop_logo_dark_mode.svg">
+  <img alt="ChemProp Logo" src="docs/source/_static/images/logo/chemprop_logo.svg">
+</picture>
+
+# Chemprop
+
+[PyPI badge](https://badge.fury.io/py/chemprop)
+[PyPI badge](https://badge.fury.io/py/chemprop)
+[conda-forge badge](https://anaconda.org/conda-forge/chemprop)
+[CI status badge](https://github.com/chemprop/chemprop/actions/workflows/tests.yml)
+[Documentation badge](https://chemprop.readthedocs.io/en/main/?badge=main)
+[MIT license badge](https://opensource.org/licenses/MIT)
+[Downloads badge](https://pepy.tech/project/chemprop)
+[Downloads badge](https://pepy.tech/project/chemprop)
+[Downloads badge](https://pepy.tech/project/chemprop)
+
+Chemprop is a repository containing message passing neural networks for molecular property prediction.
+
+Documentation can be found [here](https://chemprop.readthedocs.io/en/main/).
+
+There are tutorial notebooks in the [`examples/`](https://github.com/chemprop/chemprop/tree/main/examples) directory.
+
+Chemprop recently underwent a ground-up rewrite and new major release (v2.0.0). A helpful transition guide from Chemprop v1 to v2 can be found [here](https://docs.google.com/spreadsheets/u/3/d/e/2PACX-1vRshySIknVBBsTs5P18jL4WeqisxDAnDE5VRnzxqYEhYrMe4GLS17w5KeKPw9sged6TmmPZ4eEZSTIy/pubhtml). This includes a side-by-side comparison of CLI argument options, a list of which arguments will be implemented in later versions of v2, and a list of changes to default hyperparameters.
+
+**License:** Chemprop is free to use under the [MIT License](LICENSE.txt). The Chemprop logo is free to use under [CC0 1.0](docs/source/_static/images/logo/LICENSE.txt).
+
+**References**: Please cite the appropriate papers if Chemprop is helpful to your research.
+
+- Chemprop was initially described in the papers [Analyzing Learned Molecular Representations for Property Prediction](https://pubs.acs.org/doi/abs/10.1021/acs.jcim.9b00237) for molecules and [Machine Learning of Reaction Properties via Learned Representations of the Condensed Graph of Reaction](https://doi.org/10.1021/acs.jcim.1c00975) for reactions.
+- The interpretation functionality (available in v1, but not yet implemented in v2) is based on the paper [Multi-Objective Molecule Generation using Interpretable Substructures](https://arxiv.org/abs/2002.03244).
+- Chemprop now has its own dedicated manuscript that describes and benchmarks it in more detail: [Chemprop: A Machine Learning Package for Chemical Property Prediction](https://doi.org/10.1021/acs.jcim.3c01250).
+- A paper describing and benchmarking the changes in v2.0.0 is forthcoming.
+
+**Selected Applications**: Chemprop has been successfully used in the following works.
+
+- [A Deep Learning Approach to Antibiotic Discovery](https://www.cell.com/cell/fulltext/S0092-8674(20)30102-1) - _Cell_ (2020): Chemprop was used to predict antibiotic activity against _E. coli_, leading to the discovery of [Halicin](https://en.wikipedia.org/wiki/Halicin), a novel antibiotic candidate. Model checkpoints are available on [Zenodo](https://doi.org/10.5281/zenodo.6527882).
+- [Discovery of a structural class of antibiotics with explainable deep learning](https://www.nature.com/articles/s41586-023-06887-8) - _Nature_ (2023): Identified a structural class of antibiotics selective against methicillin-resistant _S. aureus_ (MRSA) and vancomycin-resistant enterococci using ensembles of Chemprop models, and explained results using Chemprop's interpret method.
+- [ADMET-AI: A machine learning ADMET platform for evaluation of large-scale chemical libraries](https://academic.oup.com/bioinformatics/advance-article/doi/10.1093/bioinformatics/btae416/7698030?utm_source=authortollfreelink&utm_campaign=bioinformatics&utm_medium=email&guestAccessKey=f4fca1d2-49ec-4b10-b476-5aea3bf37045): Chemprop was trained on 41 absorption, distribution, metabolism, excretion, and toxicity (ADMET) datasets from the [Therapeutics Data Commons](https://tdcommons.ai). The Chemprop models in ADMET-AI are available both as a web server at [admet.ai.greenstonebio.com](https://admet.ai.greenstonebio.com) and as a Python package at [github.com/swansonk14/admet_ai](https://github.com/swansonk14/admet_ai).
+- A more extensive list of successful Chemprop applications is given in our [2023 paper](https://doi.org/10.1021/acs.jcim.3c01250).
+
+## Version 1.x
+
+For users who have not yet made the switch to Chemprop v2.0, please reference the following resources.
+
+### v1 Documentation
+
+- Documentation of Chemprop v1 is available [here](https://chemprop.readthedocs.io/en/v1.7.1/). Note that the content of this site is several versions behind the final v1 release (v1.7.1) and does not cover the full scope of features available in chemprop v1.
+- The v1 [README](https://github.com/chemprop/chemprop/blob/v1.7.1/README.md) is the best source for documentation on more recently-added features.
+- Please also see descriptions of all the possible command line arguments in the v1 [`args.py`](https://github.com/chemprop/chemprop/blob/v1.7.1/chemprop/args.py) file.
+
+### v1 Tutorials and Examples
+
+- [Benchmark scripts](https://github.com/chemprop/chemprop_benchmark) - scripts from our 2023 paper, providing examples of many features using Chemprop v1.6.1
+- [ACS Fall 2023 Workshop](https://github.com/chemprop/chemprop-workshop-acs-fall2023) - presentation, interactive demo, exercises on Google Colab with solution key
+- [Google Colab notebook](https://colab.research.google.com/github/chemprop/chemprop/blob/v1.7.1/colab_demo.ipynb) - several examples, intended to be run in Google Colab rather than as a Jupyter notebook on your local machine
+- [nanoHUB tool](https://nanohub.org/resources/chempropdemo/) - a notebook of examples similar to the Colab notebook above, doesn't require any installation
+- [YouTube video](https://www.youtube.com/watch?v=TeOl5E8Wo2M) - lecture accompanying nanoHUB tool
+- These [slides](https://docs.google.com/presentation/d/14pbd9LTXzfPSJHyXYkfLxnK8Q80LhVnjImg8a3WqCRM/edit?usp=sharing) provide a Chemprop tutorial and highlight additions as of April 28th, 2020
+
+### v1 Known Issues
+
+We have discontinued support for v1 since v2 has been released, but we still appreciate v1 bug reports and will tag them as [`v1-wontfix`](https://github.com/chemprop/chemprop/issues?q=label%3Av1-wontfix+) so the community can find them easily.
chemprop-updated/chemprop/__init__.py
ADDED
@@ -0,0 +1,5 @@
+from . import data, exceptions, featurizers, models, nn, schedulers, utils
+
+__all__ = ["data", "featurizers", "models", "nn", "utils", "exceptions", "schedulers"]
+
+__version__ = "2.1.2"
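A quick post-install sanity check against the pinned version and the exported submodules (a minimal sketch, assuming the package built from this upload is importable):

import chemprop

# __version__ is kept in sync with the .bumpversion.cfg earlier in this diff.
assert chemprop.__version__ == "2.1.2"
print(chemprop.__all__)  # ['data', 'featurizers', 'models', 'nn', 'utils', 'exceptions', 'schedulers']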
chemprop-updated/chemprop/__pycache__/__init__.cpython-37.pyc
ADDED
Binary file (743 Bytes)
chemprop-updated/chemprop/__pycache__/args.cpython-37.pyc
ADDED
Binary file (33.7 kB)
chemprop-updated/chemprop/__pycache__/constants.cpython-37.pyc
ADDED
Binary file (430 Bytes)
chemprop-updated/chemprop/__pycache__/hyperopt_utils.cpython-37.pyc
ADDED
Binary file (11.1 kB)
chemprop-updated/chemprop/__pycache__/hyperparameter_optimization.cpython-37.pyc
ADDED
Binary file (6.15 kB)
chemprop-updated/chemprop/__pycache__/interpret.cpython-37.pyc
ADDED
Binary file (14.2 kB)
chemprop-updated/chemprop/__pycache__/multitask_utils.cpython-37.pyc
ADDED
Binary file (3.12 kB)
chemprop-updated/chemprop/__pycache__/nn_utils.cpython-37.pyc
ADDED
Binary file (8.13 kB)
chemprop-updated/chemprop/__pycache__/rdkit.cpython-37.pyc
ADDED
Binary file (1.43 kB)
chemprop-updated/chemprop/__pycache__/sklearn_predict.cpython-37.pyc
ADDED
Binary file (2.82 kB)
chemprop-updated/chemprop/__pycache__/sklearn_train.cpython-37.pyc
ADDED
Binary file (11.4 kB)
chemprop-updated/chemprop/__pycache__/spectra_utils.cpython-37.pyc
ADDED
Binary file (5.1 kB)
chemprop-updated/chemprop/__pycache__/utils.cpython-37.pyc
ADDED
Binary file (26.5 kB)
chemprop-updated/chemprop/cli/common.py
ADDED
@@ -0,0 +1,211 @@
+from argparse import ArgumentError, ArgumentParser, Namespace
+import logging
+from pathlib import Path
+
+from chemprop.cli.utils import LookupAction
+from chemprop.cli.utils.args import uppercase
+from chemprop.featurizers import AtomFeatureMode, MoleculeFeaturizerRegistry, RxnMode
+
+logger = logging.getLogger(__name__)
+
+
+def add_common_args(parser: ArgumentParser) -> ArgumentParser:
+    data_args = parser.add_argument_group("Shared input data args")
+    data_args.add_argument(
+        "-s",
+        "--smiles-columns",
+        nargs="+",
+        help="Column names in the input CSV containing SMILES strings (uses the 0th column by default)",
+    )
+    data_args.add_argument(
+        "-r",
+        "--reaction-columns",
+        nargs="+",
+        help="Column names in the input CSV containing reaction SMILES in the format ``REACTANT>AGENT>PRODUCT``, where 'AGENT' is optional",
+    )
+    data_args.add_argument(
+        "--no-header-row",
+        action="store_true",
+        help="Turn off using the first row in the input CSV as column names",
+    )
+
+    dataloader_args = parser.add_argument_group("Dataloader args")
+    dataloader_args.add_argument(
+        "-n",
+        "--num-workers",
+        type=int,
+        default=0,
+        help="""Number of workers for parallel data loading where 0 means sequential
+(Warning: setting ``num_workers`` to a value greater than 0 can cause hangs on Windows and MacOS)""",
+    )
+    dataloader_args.add_argument("-b", "--batch-size", type=int, default=64, help="Batch size")
+
+    parser.add_argument(
+        "--accelerator", default="auto", help="Passed directly to the lightning ``Trainer()``"
+    )
+    parser.add_argument(
+        "--devices",
+        default="auto",
+        help="Passed directly to the lightning ``Trainer()`` (must be a single string of comma separated devices, e.g. '1, 2' if specifying multiple devices)",
+    )
+
+    featurization_args = parser.add_argument_group("Featurization args")
+    featurization_args.add_argument(
+        "--rxn-mode",
+        "--reaction-mode",
+        type=uppercase,
+        default="REAC_DIFF",
+        choices=list(RxnMode.keys()),
+        help="""Choices for construction of atom and bond features for reactions (case insensitive):
+
+- ``REAC_PROD``: concatenates the reactants feature with the products feature
+- ``REAC_DIFF``: concatenates the reactants feature with the difference in features between reactants and products (Default)
+- ``PROD_DIFF``: concatenates the products feature with the difference in features between reactants and products
+- ``REAC_PROD_BALANCE``: concatenates the reactants feature with the products feature, balances imbalanced reactions
+- ``REAC_DIFF_BALANCE``: concatenates the reactants feature with the difference in features between reactants and products, balances imbalanced reactions
+- ``PROD_DIFF_BALANCE``: concatenates the products feature with the difference in features between reactants and products, balances imbalanced reactions""",
+    )
+    # TODO: Update documentation for multi_hot_atom_featurizer_mode
+    featurization_args.add_argument(
+        "--multi-hot-atom-featurizer-mode",
+        type=uppercase,
+        default="V2",
+        choices=list(AtomFeatureMode.keys()),
+        help="""Choices for multi-hot atom featurization scheme. This will affect both non-reaction and reaction featurization (case insensitive):
+
+- ``V1``: Corresponds to the original configuration employed in Chemprop V1
+- ``V2``: Tailored for a broad range of molecules, this configuration encompasses all elements in the first four rows of the periodic table, along with iodine. It is the default in Chemprop V2.
+- ``ORGANIC``: This configuration is designed specifically for use with organic molecules for drug research and development and includes a subset of elements most common in organic chemistry, including H, B, C, N, O, F, Si, P, S, Cl, Br, and I.
+- ``RIGR``: Modified V2 (default) featurizer using only the resonance-invariant atom and bond features.""",
+    )
+    featurization_args.add_argument(
+        "--keep-h",
+        action="store_true",
+        help="Whether hydrogens explicitly specified in input should be kept in the mol graph",
+    )
+    featurization_args.add_argument(
+        "--add-h", action="store_true", help="Whether hydrogens should be added to the mol graph"
+    )
+    featurization_args.add_argument(
+        "--molecule-featurizers",
+        "--features-generators",
+        nargs="+",
+        action=LookupAction(MoleculeFeaturizerRegistry),
+        help="Method(s) of generating molecule features to use as extra descriptors",
+    )
+    # TODO: add in v2.1 to deprecate features-generators and then remove in v2.2
+    # featurization_args.add_argument(
+    #     "--features-generators", nargs="+", help="Renamed to `--molecule-featurizers`."
+    # )
+    featurization_args.add_argument(
+        "--descriptors-path",
+        type=Path,
+        help="Path to extra descriptors to concatenate to learned representation",
+    )
+    # TODO: Add in v2.1
+    # featurization_args.add_argument(
+    #     "--phase-features-path",
+    #     help="Path to features used to indicate the phase of the data in one-hot vector form. Used in spectra datatype.",
+    # )
+    featurization_args.add_argument(
+        "--no-descriptor-scaling", action="store_true", help="Turn off extra descriptor scaling"
+    )
+    featurization_args.add_argument(
+        "--no-atom-feature-scaling", action="store_true", help="Turn off extra atom feature scaling"
+    )
+    featurization_args.add_argument(
+        "--no-atom-descriptor-scaling",
+        action="store_true",
+        help="Turn off extra atom descriptor scaling",
+    )
+    featurization_args.add_argument(
+        "--no-bond-feature-scaling", action="store_true", help="Turn off extra bond feature scaling"
+    )
+    featurization_args.add_argument(
+        "--atom-features-path",
+        nargs="+",
+        action="append",
+        help="If a single path is given, it is assumed to correspond to the 0-th molecule. Alternatively, it can be a two-tuple of molecule index and path to additional atom features to supply before message passing (e.g., ``--atom-features-path 0 /path/to/features_0.npz``) indicates that the features at the given path should be supplied to the 0-th component. To supply additional features for multiple components, repeat this argument on the command line for each component's respective values (e.g., ``--atom-features-path [...] --atom-features-path [...]``).",
+    )
+    featurization_args.add_argument(
+        "--atom-descriptors-path",
+        nargs="+",
+        action="append",
+        help="If a single path is given, it is assumed to correspond to the 0-th molecule. Alternatively, it can be a two-tuple of molecule index and path to additional atom descriptors to supply after message passing (e.g., ``--atom-descriptors-path 0 /path/to/descriptors_0.npz`` indicates that the descriptors at the given path should be supplied to the 0-th component). To supply additional descriptors for multiple components, repeat this argument on the command line for each component's respective values (e.g., ``--atom-descriptors-path [...] --atom-descriptors-path [...]``).",
+    )
+    featurization_args.add_argument(
+        "--bond-features-path",
+        nargs="+",
+        action="append",
+        help="If a single path is given, it is assumed to correspond to the 0-th molecule. Alternatively, it can be a two-tuple of molecule index and path to additional bond features to supply before message passing (e.g., ``--bond-features-path 0 /path/to/features_0.npz`` indicates that the features at the given path should be supplied to the 0-th component). To supply additional features for multiple components, repeat this argument on the command line for each component's respective values (e.g., ``--bond-features-path [...] --bond-features-path [...]``).",
+    )
+    # TODO: Add in v2.2
+    # parser.add_argument(
+    #     "--constraints-path",
+    #     help="Path to constraints applied to atomic/bond properties prediction.",
+    # )
+
+    return parser
+
+
+def process_common_args(args: Namespace) -> Namespace:
+    # TODO: add in v2.1 to deprecate features-generators and then remove in v2.2
+    # if args.features_generators is not None:
+    #     raise ArgumentError(
+    #         argument=None,
+    #         message="`--features-generators` has been renamed to `--molecule-featurizers`.",
+    #     )
+
+    for key in ["atom_features_path", "atom_descriptors_path", "bond_features_path"]:
+        inds_paths = getattr(args, key)
+
+        if not inds_paths:
+            continue
+
+        ind_path_dict = {}
+
+        for ind_path in inds_paths:
+            if len(ind_path) > 2:
+                raise ArgumentError(
+                    argument=None,
+                    message="Too many arguments were given for atom features/descriptors or bond features. It should be either a two-tuple of molecule index and a path, or a single path (assumed to be the 0-th molecule).",
+                )
+
+            if len(ind_path) == 1:
+                ind = 0
+                path = ind_path[0]
+            else:
+                ind, path = ind_path
+
+            if ind_path_dict.get(int(ind), None):
+                raise ArgumentError(
+                    argument=None,
+                    message=f"Duplicate atom features/descriptors or bond features given for molecule index {ind}",
+                )
+
+            ind_path_dict[int(ind)] = Path(path)
+
+        setattr(args, key, ind_path_dict)
+
+    return args
+
+
+def validate_common_args(args):
+    pass
+
+
+def find_models(model_paths: list[Path]):
+    collected_model_paths = []
+
+    for model_path in model_paths:
+        if model_path.suffix in [".ckpt", ".pt"]:
+            collected_model_paths.append(model_path)
+        elif model_path.is_dir():
+            collected_model_paths.extend(list(model_path.rglob("*.pt")))
+        else:
+            raise ArgumentError(
+                argument=None,
+                message=f"Expected a .ckpt or .pt file, or a directory. Got {model_path}",
+            )
+
+    return collected_model_paths
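To show how these helpers compose, a minimal sketch (assuming this package is importable; the argument values are hypothetical) that attaches the shared arguments to a fresh parser and normalizes the parsed result:

from argparse import ArgumentParser

from chemprop.cli.common import add_common_args, process_common_args

parser = add_common_args(ArgumentParser())
# hypothetical invocation: one SMILES column plus extra atom features for the
# 0-th molecule, exercising the two-tuple form documented in the help text
args = parser.parse_args(["-s", "smiles", "--atom-features-path", "0", "feats_0.npz"])
args = process_common_args(args)
print(args.atom_features_path)  # {0: PosixPath('feats_0.npz')}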
chemprop-updated/chemprop/cli/conf.py
ADDED
@@ -0,0 +1,9 @@
+from datetime import datetime
+import logging
+import os
+from pathlib import Path
+
+LOG_DIR = Path(os.getenv("CHEMPROP_LOG_DIR", "chemprop_logs"))
+LOG_LEVELS = {0: logging.INFO, 1: logging.DEBUG, -1: logging.WARNING, -2: logging.ERROR}
+NOW = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+CHEMPROP_TRAIN_DIR = Path(os.getenv("CHEMPROP_TRAIN_DIR", "chemprop_training"))
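Because these constants read the environment at import time, output locations can be redirected without code changes; a minimal sketch (directory names are hypothetical) of overriding them before the first chemprop import:

import os

# Must be set before chemprop.cli.conf is first imported.
os.environ["CHEMPROP_LOG_DIR"] = "/tmp/my_chemprop_logs"        # hypothetical
os.environ["CHEMPROP_TRAIN_DIR"] = "/tmp/my_chemprop_training"  # hypothetical

from chemprop.cli.conf import CHEMPROP_TRAIN_DIR, LOG_DIR

print(LOG_DIR, CHEMPROP_TRAIN_DIR)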
chemprop-updated/chemprop/cli/convert.py
ADDED
@@ -0,0 +1,55 @@
+from argparse import ArgumentError, ArgumentParser, Namespace
+import logging
+from pathlib import Path
+import sys
+
+from chemprop.cli.utils import Subcommand
+from chemprop.utils.v1_to_v2 import convert_model_file_v1_to_v2
+
+logger = logging.getLogger(__name__)
+
+
+class ConvertSubcommand(Subcommand):
+    COMMAND = "convert"
+    HELP = "Convert a v1 model checkpoint (.pt) to a v2 model checkpoint (.pt)."
+
+    @classmethod
+    def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
+        parser.add_argument(
+            "-i",
+            "--input-path",
+            required=True,
+            type=Path,
+            help="Path to a v1 model .pt checkpoint file",
+        )
+        parser.add_argument(
+            "-o",
+            "--output-path",
+            type=Path,
+            help="Path to which the converted model will be saved (``CURRENT_DIRECTORY/STEM_OF_INPUT_v2.pt`` by default)",
+        )
+        return parser
+
+    @classmethod
+    def func(cls, args: Namespace):
+        if args.output_path is None:
+            args.output_path = Path(args.input_path.stem + "_v2.pt")
+        if args.output_path.suffix != ".pt":
+            raise ArgumentError(
+                argument=None, message=f"Output must be a `.pt` file. Got {args.output_path}"
+            )
+
+        logger.info(
+            f"Converting v1 model checkpoint '{args.input_path}' to v2 model checkpoint '{args.output_path}'..."
+        )
+        convert_model_file_v1_to_v2(args.input_path, args.output_path)
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser = ConvertSubcommand.add_args(parser)
+
+    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
+
+    args = parser.parse_args()
+    ConvertSubcommand.func(args)
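
For reference, the same conversion can be driven programmatically; this sketch simply mirrors ConvertSubcommand.func, with a made-up input filename:

    from pathlib import Path

    from chemprop.utils.v1_to_v2 import convert_model_file_v1_to_v2

    input_path = Path("my_v1_model.pt")             # hypothetical v1 checkpoint
    output_path = Path(input_path.stem + "_v2.pt")  # same default as the CLI
    convert_model_file_v1_to_v2(input_path, output_path)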
chemprop-updated/chemprop/cli/fingerprint.py
ADDED
@@ -0,0 +1,182 @@
+from argparse import ArgumentError, ArgumentParser, Namespace
+import logging
+from pathlib import Path
+import sys
+
+import numpy as np
+import pandas as pd
+import torch
+
+from chemprop import data
+from chemprop.cli.common import add_common_args, process_common_args, validate_common_args
+from chemprop.cli.predict import find_models
+from chemprop.cli.utils import Subcommand, build_data_from_files, make_dataset
+from chemprop.models import load_model
+from chemprop.nn.metrics import LossFunctionRegistry
+
+logger = logging.getLogger(__name__)
+
+
+class FingerprintSubcommand(Subcommand):
+    COMMAND = "fingerprint"
+    HELP = "Use a pretrained chemprop model to calculate learned representations."
+
+    @classmethod
+    def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
+        parser = add_common_args(parser)
+        parser.add_argument(
+            "-i",
+            "--test-path",
+            required=True,
+            type=Path,
+            help="Path to an input CSV file containing SMILES",
+        )
+        parser.add_argument(
+            "-o",
+            "--output",
+            "--preds-path",
+            type=Path,
+            help="Specify the path where predictions will be saved. If the file extension is .npz, they will be saved as a npz file. Otherwise, the predictions will be saved as a CSV. The index of the model will be appended to the filename's stem. By default, predictions will be saved to the same location as ``--test-path`` with '_fps' appended (e.g., 'PATH/TO/TEST_PATH_fps_0.csv').",
+        )
+        parser.add_argument(
+            "--model-paths",
+            "--model-path",
+            required=True,
+            type=Path,
+            nargs="+",
+            help="Specify location of checkpoint(s) or model file(s) to use for prediction. It can be a path to either a single pretrained model checkpoint (.ckpt) or single pretrained model file (.pt), a directory that contains these files, or a list of path(s) and directory(s). If a directory, chemprop will recursively search and predict on all found (.pt) models.",
+        )
+        parser.add_argument(
+            "--ffn-block-index",
+            required=True,
+            type=int,
+            default=-1,
+            help="The index indicates which linear layer returns the encoding in the FFN. An index of 0 denotes the post-aggregation representation through a 0-layer MLP, while an index of 1 represents the output from the first linear layer in the FFN, and so forth.",
+        )
+
+        return parser
+
+    @classmethod
+    def func(cls, args: Namespace):
+        args = process_common_args(args)
+        validate_common_args(args)
+        args = process_fingerprint_args(args)
+        main(args)
+
+
+def process_fingerprint_args(args: Namespace) -> Namespace:
+    if args.test_path.suffix not in [".csv"]:
+        raise ArgumentError(
+            argument=None, message=f"Input data must be a CSV file. Got {args.test_path}"
+        )
+    if args.output is None:
+        args.output = args.test_path.parent / (args.test_path.stem + "_fps.csv")
+    if args.output.suffix not in [".csv", ".npz"]:
+        raise ArgumentError(
+            argument=None, message=f"Output must be a CSV or NPZ file. Got '{args.output}'."
+        )
+    return args
+
+
+def make_fingerprint_for_model(
+    args: Namespace, model_path: Path, multicomponent: bool, output_path: Path
+):
+    model = load_model(model_path, multicomponent)
+    model.eval()
+
+    bounded = any(
+        isinstance(model.criterion, LossFunctionRegistry[loss_function])
+        for loss_function in LossFunctionRegistry.keys()
+        if "bounded" in loss_function
+    )
+
+    format_kwargs = dict(
+        no_header_row=args.no_header_row,
+        smiles_cols=args.smiles_columns,
+        rxn_cols=args.reaction_columns,
+        target_cols=[],
+        ignore_cols=None,
+        splits_col=None,
+        weight_col=None,
+        bounded=bounded,
+    )
+
+    featurization_kwargs = dict(
+        molecule_featurizers=args.molecule_featurizers, keep_h=args.keep_h, add_h=args.add_h
+    )
+
+    test_data = build_data_from_files(
+        args.test_path,
+        **format_kwargs,
+        p_descriptors=args.descriptors_path,
+        p_atom_feats=args.atom_features_path,
+        p_bond_feats=args.bond_features_path,
+        p_atom_descs=args.atom_descriptors_path,
+        **featurization_kwargs,
+    )
+    logger.info(f"test size: {len(test_data[0])}")
+    test_dsets = [
+        make_dataset(d, args.rxn_mode, args.multi_hot_atom_featurizer_mode) for d in test_data
+    ]
+
+    if multicomponent:
+        test_dset = data.MulticomponentDataset(test_dsets)
+    else:
+        test_dset = test_dsets[0]
+
+    test_loader = data.build_dataloader(test_dset, args.batch_size, args.num_workers, shuffle=False)
+
+    logger.info(model)
+
+    with torch.no_grad():
+        if multicomponent:
+            encodings = [
+                model.encoding(batch.bmgs, batch.V_ds, batch.X_d, args.ffn_block_index)
+                for batch in test_loader
+            ]
+        else:
+            encodings = [
+                model.encoding(batch.bmg, batch.V_d, batch.X_d, args.ffn_block_index)
+                for batch in test_loader
+            ]
+        H = torch.cat(encodings, 0).numpy()
+
+    if output_path.suffix in [".npz"]:
+        np.savez(output_path, H=H)
+    elif output_path.suffix == ".csv":
+        fingerprint_columns = [f"fp_{i}" for i in range(H.shape[1])]
+        df_fingerprints = pd.DataFrame(H, columns=fingerprint_columns)
+        df_fingerprints.to_csv(output_path, index=False)
+    else:
+        raise ArgumentError(
+            argument=None, message=f"Output must be a CSV or npz file. Got {args.output}."
+        )
+    logger.info(f"Fingerprints saved to '{output_path}'")
+
+
+def main(args):
+    match (args.smiles_columns, args.reaction_columns):
+        case [None, None]:
+            n_components = 1
+        case [_, None]:
+            n_components = len(args.smiles_columns)
+        case [None, _]:
+            n_components = len(args.reaction_columns)
+        case _:
+            n_components = len(args.smiles_columns) + len(args.reaction_columns)
+
+    multicomponent = n_components > 1
+
+    for i, model_path in enumerate(find_models(args.model_paths)):
+        logger.info(f"Fingerprints with model {i} at '{model_path}'")
+        output_path = args.output.parent / f"{args.output.stem}_{i}{args.output.suffix}"
+        make_fingerprint_for_model(args, model_path, multicomponent, output_path)
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser = FingerprintSubcommand.add_args(parser)
+
+    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
+    args = parser.parse_args()
+    args = FingerprintSubcommand.func(args)
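
Since one output file is written per model (the model index is appended to the stem), loading the results back looks roughly like this; the file names are illustrative placeholders:

    import numpy as np
    import pandas as pd

    # NPZ output stores the encodings under the key "H" (see np.savez above)
    H = np.load("smiles_test_fps_0.npz")["H"]

    # CSV output has one fp_<i> column per encoding dimension
    df = pd.read_csv("smiles_test_fps_0.csv")
    fingerprints = df[[c for c in df.columns if c.startswith("fp_")]].to_numpy()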
chemprop-updated/chemprop/cli/hpopt.py
ADDED
@@ -0,0 +1,537 @@
+from copy import deepcopy
+import logging
+from pathlib import Path
+import shutil
+import sys
+
+from configargparse import ArgumentParser, Namespace
+from lightning import pytorch as pl
+from lightning.pytorch.callbacks import EarlyStopping
+import numpy as np
+import torch
+
+from chemprop.cli.common import add_common_args, process_common_args, validate_common_args
+from chemprop.cli.train import (
+    TrainSubcommand,
+    add_train_args,
+    build_datasets,
+    build_model,
+    build_splits,
+    normalize_inputs,
+    process_train_args,
+    save_config,
+    validate_train_args,
+)
+from chemprop.cli.utils.command import Subcommand
+from chemprop.data import build_dataloader
+from chemprop.nn import AggregationRegistry, MetricRegistry
+from chemprop.nn.transforms import UnscaleTransform
+from chemprop.nn.utils import Activation
+
+NO_RAY = False
+DEFAULT_SEARCH_SPACE = {
+    "activation": None,
+    "aggregation": None,
+    "aggregation_norm": None,
+    "batch_size": None,
+    "depth": None,
+    "dropout": None,
+    "ffn_hidden_dim": None,
+    "ffn_num_layers": None,
+    "final_lr_ratio": None,
+    "message_hidden_dim": None,
+    "init_lr_ratio": None,
+    "max_lr": None,
+    "warmup_epochs": None,
+}
+
+try:
+    import ray
+    from ray import tune
+    from ray.train import CheckpointConfig, RunConfig, ScalingConfig
+    from ray.train.lightning import (
+        RayDDPStrategy,
+        RayLightningEnvironment,
+        RayTrainReportCallback,
+        prepare_trainer,
+    )
+    from ray.train.torch import TorchTrainer
+    from ray.tune.schedulers import ASHAScheduler, FIFOScheduler
+
+    DEFAULT_SEARCH_SPACE = {
+        "activation": tune.choice(categories=list(Activation.keys())),
+        "aggregation": tune.choice(categories=list(AggregationRegistry.keys())),
+        "aggregation_norm": tune.quniform(lower=1, upper=200, q=1),
+        "batch_size": tune.choice([16, 32, 64, 128, 256]),
+        "depth": tune.qrandint(lower=2, upper=6, q=1),
+        "dropout": tune.choice([0.0] * 8 + list(np.arange(0.05, 0.45, 0.05))),
+        "ffn_hidden_dim": tune.qrandint(lower=300, upper=2400, q=100),
+        "ffn_num_layers": tune.qrandint(lower=1, upper=3, q=1),
+        "final_lr_ratio": tune.loguniform(lower=1e-2, upper=1),
+        "message_hidden_dim": tune.qrandint(lower=300, upper=2400, q=100),
+        "init_lr_ratio": tune.loguniform(lower=1e-2, upper=1),
+        "max_lr": tune.loguniform(lower=1e-4, upper=1e-2),
+        "warmup_epochs": None,
+    }
+except ImportError:
+    NO_RAY = True
+
+NO_HYPEROPT = False
+try:
+    from ray.tune.search.hyperopt import HyperOptSearch
+except ImportError:
+    NO_HYPEROPT = True
+
+NO_OPTUNA = False
+try:
+    from ray.tune.search.optuna import OptunaSearch
+except ImportError:
+    NO_OPTUNA = True
+
+
+logger = logging.getLogger(__name__)
+
+SEARCH_SPACE = DEFAULT_SEARCH_SPACE
+
+SEARCH_PARAM_KEYWORDS_MAP = {
+    "basic": ["depth", "ffn_num_layers", "dropout", "ffn_hidden_dim", "message_hidden_dim"],
+    "learning_rate": ["max_lr", "init_lr_ratio", "final_lr_ratio", "warmup_epochs"],
+    "all": list(DEFAULT_SEARCH_SPACE.keys()),
+    "init_lr": ["init_lr_ratio"],
+    "final_lr": ["final_lr_ratio"],
+}
+
+
+class HpoptSubcommand(Subcommand):
+    COMMAND = "hpopt"
+    HELP = "Perform hyperparameter optimization on the given task."
+
+    @classmethod
+    def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
+        parser = add_common_args(parser)
+        parser = add_train_args(parser)
+        return add_hpopt_args(parser)
+
+    @classmethod
+    def func(cls, args: Namespace):
+        args = process_common_args(args)
+        args = process_train_args(args)
+        args = process_hpopt_args(args)
+        validate_common_args(args)
+        validate_train_args(args)
+        main(args)
+
+
+def add_hpopt_args(parser: ArgumentParser) -> ArgumentParser:
+    hpopt_args = parser.add_argument_group("Chemprop hyperparameter optimization arguments")
+
+    hpopt_args.add_argument(
+        "--search-parameter-keywords",
+        type=str,
+        nargs="+",
+        default=["basic"],
+        help=f"""The model parameters over which to search for an optimal hyperparameter configuration. Some options are bundles of parameters or otherwise special parameter operations. Special keywords include:
+    - ``basic``: Default set of hyperparameters for search (depth, ffn_num_layers, dropout, message_hidden_dim, and ffn_hidden_dim)
+    - ``learning_rate``: Search for max_lr, init_lr_ratio, final_lr_ratio, and warmup_epochs. The search for init_lr and final_lr values are defined as fractions of the max_lr value. The search for warmup_epochs is as a fraction of the total epochs used.
+    - ``all``: Include search for all 13 individual keyword options (including: activation, aggregation, aggregation_norm, and batch_size which aren't included in the other two keywords).
+Individual supported parameters:
+    {list(DEFAULT_SEARCH_SPACE.keys())}
+""",
+    )
+
+    hpopt_args.add_argument(
+        "--hpopt-save-dir",
+        type=Path,
+        help="Directory to save the hyperparameter optimization results",
+    )
+
+    raytune_args = parser.add_argument_group("Ray Tune arguments")
+
+    raytune_args.add_argument(
+        "--raytune-num-samples",
+        type=int,
+        default=10,
+        help="Passed directly to Ray Tune ``TuneConfig`` to control number of trials to run",
+    )
+
+    raytune_args.add_argument(
+        "--raytune-search-algorithm",
+        choices=["random", "hyperopt", "optuna"],
+        default="hyperopt",
+        help="Passed to Ray Tune ``TuneConfig`` to control search algorithm",
+    )
+
+    raytune_args.add_argument(
+        "--raytune-trial-scheduler",
+        choices=["FIFO", "AsyncHyperBand"],
+        default="FIFO",
+        help="Passed to Ray Tune ``TuneConfig`` to control trial scheduler",
+    )
+
+    raytune_args.add_argument(
+        "--raytune-num-workers",
+        type=int,
+        default=1,
+        help="Passed directly to Ray Tune ``ScalingConfig`` to control number of workers to use",
+    )
+
+    raytune_args.add_argument(
+        "--raytune-use-gpu",
+        action="store_true",
+        help="Passed directly to Ray Tune ``ScalingConfig`` to control whether to use GPUs",
+    )
+
+    raytune_args.add_argument(
+        "--raytune-num-checkpoints-to-keep",
+        type=int,
+        default=1,
+        help="Passed directly to Ray Tune ``CheckpointConfig`` to control number of checkpoints to keep",
+    )
+
+    raytune_args.add_argument(
+        "--raytune-grace-period",
+        type=int,
+        default=10,
+        help="Passed directly to Ray Tune ``ASHAScheduler`` to control grace period",
+    )
+
+    raytune_args.add_argument(
+        "--raytune-reduction-factor",
+        type=int,
+        default=2,
+        help="Passed directly to Ray Tune ``ASHAScheduler`` to control reduction factor",
+    )
+
+    raytune_args.add_argument(
+        "--raytune-temp-dir", help="Passed directly to Ray Tune init to control temporary directory"
+    )
+
+    raytune_args.add_argument(
+        "--raytune-num-cpus",
+        type=int,
+        help="Passed directly to Ray Tune init to control number of CPUs to use",
+    )
+
+    raytune_args.add_argument(
+        "--raytune-num-gpus",
+        type=int,
+        help="Passed directly to Ray Tune init to control number of GPUs to use",
+    )
+
+    raytune_args.add_argument(
+        "--raytune-max-concurrent-trials",
+        type=int,
+        help="Passed directly to Ray Tune TuneConfig to control maximum concurrent trials",
+    )
+
+    hyperopt_args = parser.add_argument_group("Hyperopt arguments")
+
+    hyperopt_args.add_argument(
+        "--hyperopt-n-initial-points",
+        type=int,
+        help="Passed directly to ``HyperOptSearch`` to control number of initial points to sample",
+    )
+
+    hyperopt_args.add_argument(
+        "--hyperopt-random-state-seed",
+        type=int,
+        default=None,
+        help="Passed directly to ``HyperOptSearch`` to control random state seed",
+    )
+
+    return parser
+
+
+def process_hpopt_args(args: Namespace) -> Namespace:
+    if args.hpopt_save_dir is None:
+        args.hpopt_save_dir = Path(f"chemprop_hpopt/{args.data_path.stem}")
+
+    args.hpopt_save_dir.mkdir(exist_ok=True, parents=True)
+
+    search_parameters = set()
+
+    available_search_parameters = list(SEARCH_SPACE.keys()) + list(SEARCH_PARAM_KEYWORDS_MAP.keys())
+
+    for keyword in args.search_parameter_keywords:
+        if keyword not in available_search_parameters:
+            raise ValueError(
+                f"Search parameter keyword: {keyword} not in available options: {available_search_parameters}."
+            )
+
+        search_parameters.update(
+            SEARCH_PARAM_KEYWORDS_MAP[keyword]
+            if keyword in SEARCH_PARAM_KEYWORDS_MAP
+            else [keyword]
+        )
+
+    args.search_parameter_keywords = list(search_parameters)
+
+    if not args.hyperopt_n_initial_points:
+        args.hyperopt_n_initial_points = args.raytune_num_samples // 2
+
+    return args
+
+
+def build_search_space(search_parameters: list[str], train_epochs: int) -> dict:
+    if "warmup_epochs" in search_parameters and SEARCH_SPACE.get("warmup_epochs", None) is None:
+        assert (
+            train_epochs >= 6
+        ), "Training epochs must be at least 6 to perform hyperparameter optimization for warmup_epochs."
+        SEARCH_SPACE["warmup_epochs"] = tune.qrandint(lower=1, upper=train_epochs // 2, q=1)
+
+    return {param: SEARCH_SPACE[param] for param in search_parameters}
+
+
+def update_args_with_config(args: Namespace, config: dict) -> Namespace:
+    args = deepcopy(args)
+
+    for key, value in config.items():
+        match key:
+            case "final_lr_ratio":
+                setattr(args, "final_lr", value * config.get("max_lr", args.max_lr))
+
+            case "init_lr_ratio":
+                setattr(args, "init_lr", value * config.get("max_lr", args.max_lr))
+
+            case _:
+                assert key in args, f"Key: {key} not found in args."
+                setattr(args, key, value)
+
+    return args
+
+
+def train_model(config, args, train_dset, val_dset, logger, output_transform, input_transforms):
+    args = update_args_with_config(args, config)
+
+    train_loader = build_dataloader(
+        train_dset, args.batch_size, args.num_workers, seed=args.data_seed
+    )
+    val_loader = build_dataloader(val_dset, args.batch_size, args.num_workers, shuffle=False)
+
+    seed = args.pytorch_seed if args.pytorch_seed is not None else torch.seed()
+
+    torch.manual_seed(seed)
+
+    model = build_model(args, train_loader.dataset, output_transform, input_transforms)
+    logger.info(model)
+
+    if args.tracking_metric == "val_loss":
+        T_tracking_metric = model.criterion.__class__
+    else:
+        T_tracking_metric = MetricRegistry[args.tracking_metric]
+        args.tracking_metric = "val/" + args.tracking_metric
+
+    monitor_mode = "max" if T_tracking_metric.higher_is_better else "min"
+    logger.debug(f"Evaluation metric: '{T_tracking_metric.alias}', mode: '{monitor_mode}'")
+
+    patience = args.patience if args.patience is not None else args.epochs
+    early_stopping = EarlyStopping(args.tracking_metric, patience=patience, mode=monitor_mode)
+
+    trainer = pl.Trainer(
+        accelerator=args.accelerator,
+        devices=args.devices,
+        max_epochs=args.epochs,
+        gradient_clip_val=args.grad_clip,
+        strategy=RayDDPStrategy(),
+        callbacks=[RayTrainReportCallback(), early_stopping],
+        plugins=[RayLightningEnvironment()],
+        deterministic=args.pytorch_seed is not None,
+    )
+    trainer = prepare_trainer(trainer)
+    trainer.fit(model, train_loader, val_loader)
+
+
+def tune_model(
+    args, train_dset, val_dset, logger, monitor_mode, output_transform, input_transforms
+):
+    match args.raytune_trial_scheduler:
+        case "FIFO":
+            scheduler = FIFOScheduler()
+        case "AsyncHyperBand":
+            scheduler = ASHAScheduler(
+                max_t=args.epochs,
+                grace_period=min(args.raytune_grace_period, args.epochs),
+                reduction_factor=args.raytune_reduction_factor,
+            )
+        case _:
+            raise ValueError(f"Invalid trial scheduler! got: {args.raytune_trial_scheduler}.")
+
+    resources_per_worker = {}
+    if args.raytune_num_cpus and args.raytune_max_concurrent_trials:
+        resources_per_worker["CPU"] = args.raytune_num_cpus / args.raytune_max_concurrent_trials
+    if args.raytune_num_gpus and args.raytune_max_concurrent_trials:
+        resources_per_worker["GPU"] = args.raytune_num_gpus / args.raytune_max_concurrent_trials
+    if not resources_per_worker:
+        resources_per_worker = None
+
+    if args.raytune_num_gpus:
+        use_gpu = True
+    else:
+        use_gpu = args.raytune_use_gpu
+
+    scaling_config = ScalingConfig(
+        num_workers=args.raytune_num_workers,
+        use_gpu=use_gpu,
+        resources_per_worker=resources_per_worker,
+        trainer_resources={"CPU": 0},
+    )
+
+    checkpoint_config = CheckpointConfig(
+        num_to_keep=args.raytune_num_checkpoints_to_keep,
+        checkpoint_score_attribute=args.tracking_metric,
+        checkpoint_score_order=monitor_mode,
+    )
+
+    run_config = RunConfig(
+        checkpoint_config=checkpoint_config,
+        storage_path=args.hpopt_save_dir.absolute() / "ray_results",
+    )
+
+    ray_trainer = TorchTrainer(
+        lambda config: train_model(
+            config, args, train_dset, val_dset, logger, output_transform, input_transforms
+        ),
+        scaling_config=scaling_config,
+        run_config=run_config,
+    )
+
+    match args.raytune_search_algorithm:
+        case "random":
+            search_alg = None
+        case "hyperopt":
+            if NO_HYPEROPT:
+                raise ImportError(
+                    "HyperOptSearch requires hyperopt to be installed. Use 'pip install -U hyperopt' to install or use 'pip install -e .[hpopt]' in chemprop folder if you installed from source to install all hpopt relevant packages."
+                )
+
+            search_alg = HyperOptSearch(
+                n_initial_points=args.hyperopt_n_initial_points,
+                random_state_seed=args.hyperopt_random_state_seed,
+            )
+        case "optuna":
+            if NO_OPTUNA:
+                raise ImportError(
+                    "OptunaSearch requires optuna to be installed. Use 'pip install -U optuna' to install or use 'pip install -e .[hpopt]' in chemprop folder if you installed from source to install all hpopt relevant packages."
+                )
+
+            search_alg = OptunaSearch()
+
+    tune_config = tune.TuneConfig(
+        metric=args.tracking_metric,
+        mode=monitor_mode,
+        num_samples=args.raytune_num_samples,
+        scheduler=scheduler,
+        search_alg=search_alg,
+        trial_dirname_creator=lambda trial: str(trial.trial_id),
+    )
+
+    tuner = tune.Tuner(
+        ray_trainer,
+        param_space={
+            "train_loop_config": build_search_space(args.search_parameter_keywords, args.epochs)
+        },
+        tune_config=tune_config,
+    )
+
+    return tuner.fit()
+
+
+def main(args: Namespace):
+    if NO_RAY:
+        raise ImportError(
+            "Ray Tune requires ray to be installed. If you installed Chemprop from PyPI, run 'pip install -U ray[tune]' to install ray. If you installed from source, use 'pip install -e .[hpopt]' in Chemprop folder to install all hpopt relevant packages."
+        )
+
+    if not ray.is_initialized():
+        try:
+            ray.init(
+                _temp_dir=args.raytune_temp_dir,
+                num_cpus=args.raytune_num_cpus,
+                num_gpus=args.raytune_num_gpus,
+            )
+        except OSError as e:
+            if "AF_UNIX path length cannot exceed 107 bytes" in str(e):
+                raise OSError(
+                    f"Ray Tune fails due to: {e}. This can sometimes be solved by providing a temporary directory, num_cpus, and num_gpus to Ray Tune via the CLI: --raytune-temp-dir <absolute_path> --raytune-num-cpus <int> --raytune-num-gpus <int>."
+                )
+            else:
+                raise e
+    else:
+        logger.info("Ray is already initialized.")
+
+    format_kwargs = dict(
+        no_header_row=args.no_header_row,
+        smiles_cols=args.smiles_columns,
+        rxn_cols=args.reaction_columns,
+        target_cols=args.target_columns,
+        ignore_cols=args.ignore_columns,
+        splits_col=args.splits_column,
+        weight_col=args.weight_column,
+        bounded=args.loss_function is not None and "bounded" in args.loss_function,
+    )
+
+    featurization_kwargs = dict(
+        molecule_featurizers=args.molecule_featurizers, keep_h=args.keep_h, add_h=args.add_h
+    )
+
+    train_data, val_data, test_data = build_splits(args, format_kwargs, featurization_kwargs)
+    train_dset, val_dset, test_dset = build_datasets(args, train_data[0], val_data[0], test_data[0])
+
+    input_transforms = normalize_inputs(train_dset, val_dset, args)
+
+    if "regression" in args.task_type:
+        output_scaler = train_dset.normalize_targets()
+        val_dset.normalize_targets(output_scaler)
+        logger.info(f"Train data: mean = {output_scaler.mean_} | std = {output_scaler.scale_}")
+        output_transform = UnscaleTransform.from_standard_scaler(output_scaler)
+    else:
+        output_transform = None
+
+    train_loader = build_dataloader(
+        train_dset, args.batch_size, args.num_workers, seed=args.data_seed
+    )
+
+    model = build_model(args, train_loader.dataset, output_transform, input_transforms)
+    monitor_mode = "max" if model.metrics[0].higher_is_better else "min"
+
+    results = tune_model(
+        args, train_dset, val_dset, logger, monitor_mode, output_transform, input_transforms
+    )
+
+    best_result = results.get_best_result()
+    best_config = best_result.config["train_loop_config"]
+    best_checkpoint_path = Path(best_result.checkpoint.path) / "checkpoint.ckpt"
+
+    best_config_save_path = args.hpopt_save_dir / "best_config.toml"
+    best_checkpoint_save_path = args.hpopt_save_dir / "best_checkpoint.ckpt"
+    all_progress_save_path = args.hpopt_save_dir / "all_progress.csv"
+
+    logger.info(f"Best hyperparameters saved to: '{best_config_save_path}'")
+
+    args = update_args_with_config(args, best_config)
+
+    args = TrainSubcommand.parser.parse_known_args(namespace=args)[0]
+    save_config(TrainSubcommand.parser, args, best_config_save_path)
+
+    logger.info(
+        f"Best hyperparameter configuration checkpoint saved to '{best_checkpoint_save_path}'"
+    )
+
+    shutil.copyfile(best_checkpoint_path, best_checkpoint_save_path)
+
+    logger.info(f"Hyperparameter optimization results saved to '{all_progress_save_path}'")
+
+    result_df = results.get_dataframe()
+
+    result_df.to_csv(all_progress_save_path, index=False)
+
+    ray.shutdown()
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser = HpoptSubcommand.add_args(parser)
+
+    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
+    args = parser.parse_args()
+    HpoptSubcommand.func(args)
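
Note how update_args_with_config resolves the ratio-style parameters: init_lr and final_lr are reconstructed from max_lr, so only the ratios (and max_lr itself) are searched. A toy illustration of that arithmetic, with invented values:

    config = {"max_lr": 1e-3, "init_lr_ratio": 0.1, "final_lr_ratio": 0.05}

    max_lr = config["max_lr"]
    init_lr = config["init_lr_ratio"] * max_lr    # 1e-4, as in update_args_with_config
    final_lr = config["final_lr_ratio"] * max_lr  # 5e-5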
chemprop-updated/chemprop/cli/main.py
ADDED
@@ -0,0 +1,85 @@
+import logging
+from pathlib import Path
+import sys
+
+from configargparse import ArgumentParser
+
+from chemprop.cli.conf import LOG_DIR, LOG_LEVELS, NOW
+from chemprop.cli.convert import ConvertSubcommand
+from chemprop.cli.fingerprint import FingerprintSubcommand
+from chemprop.cli.hpopt import HpoptSubcommand
+from chemprop.cli.predict import PredictSubcommand
+from chemprop.cli.train import TrainSubcommand
+from chemprop.cli.utils import pop_attr
+
+logger = logging.getLogger(__name__)
+
+SUBCOMMANDS = [
+    TrainSubcommand,
+    PredictSubcommand,
+    ConvertSubcommand,
+    FingerprintSubcommand,
+    HpoptSubcommand,
+]
+
+
+def construct_parser():
+    parser = ArgumentParser()
+    subparsers = parser.add_subparsers(title="mode", dest="mode", required=True)
+
+    parent = ArgumentParser(add_help=False)
+    parent.add_argument(
+        "--logfile",
+        "--log",
+        nargs="?",
+        const="default",
+        help=f"Path to which the log file should be written (specifying just the flag alone will automatically log to a file ``{LOG_DIR}/MODE/TIMESTAMP.log``, where 'MODE' is the CLI mode chosen, e.g., ``{LOG_DIR}/MODE/{NOW}.log``)",
+    )
+    parent.add_argument("-v", action="store_true", help="Increase verbosity level to DEBUG")
+    parent.add_argument(
+        "-q",
+        action="count",
+        default=0,
+        help="Decrease verbosity level to WARNING or ERROR if specified twice",
+    )
+
+    parents = [parent]
+    for subcommand in SUBCOMMANDS:
+        subcommand.add(subparsers, parents)
+
+    return parser
+
+
+def main():
+    parser = construct_parser()
+    args = parser.parse_args()
+    logfile, v_flag, q_count, mode, func = (
+        pop_attr(args, attr) for attr in ["logfile", "v", "q", "mode", "func"]
+    )
+
+    if v_flag and q_count:
+        parser.error("The -v and -q options cannot be used together.")
+
+    match logfile:
+        case None:
+            handler = logging.StreamHandler(sys.stderr)
+        case "default":
+            (LOG_DIR / mode).mkdir(parents=True, exist_ok=True)
+            handler = logging.FileHandler(str(LOG_DIR / mode / f"{NOW}.log"))
+        case _:
+            Path(logfile).parent.mkdir(parents=True, exist_ok=True)
+            handler = logging.FileHandler(logfile)
+
+    verbosity = q_count * -1 if q_count else (1 if v_flag else 0)
+    logging_level = LOG_LEVELS.get(verbosity, logging.ERROR)
+    logging.basicConfig(
+        handlers=[handler],
+        format="%(asctime)s - %(levelname)s:%(name)s - %(message)s",
+        level=logging_level,
+        datefmt="%Y-%m-%dT%H:%M:%S",
+        force=True,
+    )
+
+    logger.info(f"Running in mode '{mode}' with args: {vars(args)}")
+
+    func(args)
ADDED
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from argparse import ArgumentError, ArgumentParser, Namespace
|
2 |
+
import logging
|
3 |
+
from pathlib import Path
|
4 |
+
import sys
|
5 |
+
from typing import Iterator
|
6 |
+
|
7 |
+
from lightning import pytorch as pl
|
8 |
+
import numpy as np
|
9 |
+
import pandas as pd
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from chemprop import data
|
13 |
+
from chemprop.cli.common import (
|
14 |
+
add_common_args,
|
15 |
+
find_models,
|
16 |
+
process_common_args,
|
17 |
+
validate_common_args,
|
18 |
+
)
|
19 |
+
from chemprop.cli.utils import LookupAction, Subcommand, build_data_from_files, make_dataset
|
20 |
+
from chemprop.models.utils import load_model, load_output_columns
|
21 |
+
from chemprop.nn.metrics import LossFunctionRegistry
|
22 |
+
from chemprop.nn.predictors import EvidentialFFN, MulticlassClassificationFFN, MveFFN
|
23 |
+
from chemprop.uncertainty import (
|
24 |
+
MVEWeightingCalibrator,
|
25 |
+
NoUncertaintyEstimator,
|
26 |
+
RegressionCalibrator,
|
27 |
+
RegressionEvaluator,
|
28 |
+
UncertaintyCalibratorRegistry,
|
29 |
+
UncertaintyEstimatorRegistry,
|
30 |
+
UncertaintyEvaluatorRegistry,
|
31 |
+
)
|
32 |
+
from chemprop.utils import Factory
|
33 |
+
|
34 |
+
logger = logging.getLogger(__name__)
|
35 |
+
|
36 |
+
|
37 |
+
class PredictSubcommand(Subcommand):
|
38 |
+
COMMAND = "predict"
|
39 |
+
HELP = "use a pretrained chemprop model for prediction"
|
40 |
+
|
41 |
+
@classmethod
|
42 |
+
def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
|
43 |
+
parser = add_common_args(parser)
|
44 |
+
return add_predict_args(parser)
|
45 |
+
|
46 |
+
@classmethod
|
47 |
+
def func(cls, args: Namespace):
|
48 |
+
args = process_common_args(args)
|
49 |
+
validate_common_args(args)
|
50 |
+
args = process_predict_args(args)
|
51 |
+
main(args)
|
52 |
+
|
53 |
+
|
54 |
+
def add_predict_args(parser: ArgumentParser) -> ArgumentParser:
|
55 |
+
parser.add_argument(
|
56 |
+
"-i",
|
57 |
+
"--test-path",
|
58 |
+
required=True,
|
59 |
+
type=Path,
|
60 |
+
help="Path to an input CSV file containing SMILES",
|
61 |
+
)
|
62 |
+
parser.add_argument(
|
63 |
+
"-o",
|
64 |
+
"--output",
|
65 |
+
"--preds-path",
|
66 |
+
type=Path,
|
67 |
+
help="Specify path to which predictions will be saved. If the file extension is .pkl, it will be saved as a pickle file. Otherwise, chemprop will save predictions as a CSV. If multiple models are used to make predictions, the average predictions will be saved in the file, and another file ending in '_individual' with the same file extension will save the predictions for each individual model, with the column names being the target names appended with the model index (e.g., '_model_<index>').",
|
68 |
+
)
|
69 |
+
parser.add_argument(
|
70 |
+
"--drop-extra-columns",
|
71 |
+
action="store_true",
|
72 |
+
help="Whether to drop all columns from the test data file besides the SMILES columns and the new prediction columns",
|
73 |
+
)
|
74 |
+
parser.add_argument(
|
75 |
+
"--model-paths",
|
76 |
+
"--model-path",
|
77 |
+
required=True,
|
78 |
+
type=Path,
|
79 |
+
nargs="+",
|
80 |
+
help="Location of checkpoint(s) or model file(s) to use for prediction. It can be a path to either a single pretrained model checkpoint (.ckpt) or single pretrained model file (.pt), a directory that contains these files, or a list of path(s) and directory(s). If a directory, will recursively search and predict on all found (.pt) models.",
|
81 |
+
)
|
82 |
+
|
83 |
+
unc_args = parser.add_argument_group("Uncertainty and calibration args")
|
84 |
+
unc_args.add_argument(
|
85 |
+
"--cal-path", type=Path, help="Path to data file to be used for uncertainty calibration."
|
86 |
+
)
|
87 |
+
unc_args.add_argument(
|
88 |
+
"--uncertainty-method",
|
89 |
+
default="none",
|
90 |
+
action=LookupAction(UncertaintyEstimatorRegistry),
|
91 |
+
help="The method of calculating uncertainty.",
|
92 |
+
)
|
93 |
+
unc_args.add_argument(
|
94 |
+
"--calibration-method",
|
95 |
+
action=LookupAction(UncertaintyCalibratorRegistry),
|
96 |
+
help="The method used for calibrating the uncertainty calculated with uncertainty method.",
|
97 |
+
)
|
98 |
+
unc_args.add_argument(
|
99 |
+
"--evaluation-methods",
|
100 |
+
"--evaluation-method",
|
101 |
+
nargs="+",
|
102 |
+
action=LookupAction(UncertaintyEvaluatorRegistry),
|
103 |
+
help="The methods used for evaluating the uncertainty performance if the test data provided includes targets. Available methods are [nll, miscalibration_area, ence, spearman] or any available classification or multiclass metric.",
|
104 |
+
)
|
105 |
+
# unc_args.add_argument(
|
106 |
+
# "--evaluation-scores-path", help="Location to save the results of uncertainty evaluations."
|
107 |
+
# )
|
108 |
+
unc_args.add_argument(
|
109 |
+
"--uncertainty-dropout-p",
|
110 |
+
type=float,
|
111 |
+
default=0.1,
|
112 |
+
help="The probability to use for Monte Carlo dropout uncertainty estimation.",
|
113 |
+
)
|
114 |
+
unc_args.add_argument(
|
115 |
+
"--dropout-sampling-size",
|
116 |
+
type=int,
|
117 |
+
default=10,
|
118 |
+
help="The number of samples to use for Monte Carlo dropout uncertainty estimation. Distinct from the dropout used during training.",
|
119 |
+
)
|
120 |
+
unc_args.add_argument(
|
121 |
+
"--calibration-interval-percentile",
|
122 |
+
type=float,
|
123 |
+
default=95,
|
124 |
+
help="Sets the percentile used in the calibration methods. Must be in the range (1, 100).",
|
125 |
+
)
|
126 |
+
unc_args.add_argument(
|
127 |
+
"--conformal-alpha",
|
128 |
+
type=float,
|
129 |
+
default=0.1,
|
130 |
+
help="Target error rate for conformal prediction. Must be in the range (0, 1).",
|
131 |
+
)
|
132 |
+
# TODO: Decide if we want to implment this in v2.1.x
|
133 |
+
# unc_args.add_argument(
|
134 |
+
# "--regression-calibrator-metric",
|
135 |
+
# choices=["stdev", "interval"],
|
136 |
+
# help="Regression calibrators can output either a stdev or an inverval.",
|
137 |
+
# )
|
138 |
+
unc_args.add_argument(
|
139 |
+
"--cal-descriptors-path",
|
140 |
+
nargs="+",
|
141 |
+
action="append",
|
142 |
+
help="Path to extra descriptors to concatenate to learned representation in calibration dataset.",
|
143 |
+
)
|
144 |
+
# TODO: Add in v2.1.x
|
145 |
+
# unc_args.add_argument(
|
146 |
+
# "--calibration-phase-features-path",
|
147 |
+
# help=" ",
|
148 |
+
# )
|
149 |
+
unc_args.add_argument(
|
150 |
+
"--cal-atom-features-path",
|
151 |
+
nargs="+",
|
152 |
+
action="append",
|
153 |
+
help="Path to the extra atom features in calibration dataset.",
|
154 |
+
)
|
155 |
+
unc_args.add_argument(
|
156 |
+
"--cal-atom-descriptors-path",
|
157 |
+
nargs="+",
|
158 |
+
action="append",
|
159 |
+
help="Path to the extra atom descriptors in calibration dataset.",
|
160 |
+
)
|
161 |
+
unc_args.add_argument(
|
162 |
+
"--cal-bond-features-path",
|
163 |
+
nargs="+",
|
164 |
+
action="append",
|
165 |
+
help="Path to the extra bond descriptors in calibration dataset.",
|
166 |
+
)
|
167 |
+
|
168 |
+
return parser
|
169 |
+
|
170 |
+
|
171 |
+
def process_predict_args(args: Namespace) -> Namespace:
|
172 |
+
if args.test_path.suffix not in [".csv"]:
|
173 |
+
raise ArgumentError(
|
174 |
+
argument=None, message=f"Input data must be a CSV file. Got {args.test_path}"
|
175 |
+
)
|
176 |
+
if args.output is None:
|
177 |
+
args.output = args.test_path.parent / (args.test_path.stem + "_preds.csv")
|
178 |
+
if args.output.suffix not in [".csv", ".pkl"]:
|
179 |
+
raise ArgumentError(
|
180 |
+
argument=None, message=f"Output must be a CSV or Pickle file. Got {args.output}"
|
181 |
+
)
|
182 |
+
return args
|
183 |
+
|
184 |
+
|
185 |
+
def prepare_data_loader(
|
186 |
+
args: Namespace, multicomponent: bool, is_calibration: bool, format_kwargs: dict
|
187 |
+
):
|
188 |
+
data_path = args.cal_path if is_calibration else args.test_path
|
189 |
+
descriptors_path = args.cal_descriptors_path if is_calibration else args.descriptors_path
|
190 |
+
atom_feats_path = args.cal_atom_features_path if is_calibration else args.atom_features_path
|
191 |
+
bond_feats_path = args.cal_bond_features_path if is_calibration else args.bond_features_path
|
192 |
+
atom_descs_path = (
|
193 |
+
args.cal_atom_descriptors_path if is_calibration else args.atom_descriptors_path
|
194 |
+
)
|
195 |
+
|
196 |
+
featurization_kwargs = dict(
|
197 |
+
molecule_featurizers=args.molecule_featurizers, keep_h=args.keep_h, add_h=args.add_h
|
198 |
+
)
|
199 |
+
|
200 |
+
datas = build_data_from_files(
|
201 |
+
data_path,
|
202 |
+
**format_kwargs,
|
203 |
+
p_descriptors=descriptors_path,
|
204 |
+
p_atom_feats=atom_feats_path,
|
205 |
+
p_bond_feats=bond_feats_path,
|
206 |
+
p_atom_descs=atom_descs_path,
|
207 |
+
**featurization_kwargs,
|
208 |
+
)
|
209 |
+
|
210 |
+
dsets = [make_dataset(d, args.rxn_mode, args.multi_hot_atom_featurizer_mode) for d in datas]
|
211 |
+
dset = data.MulticomponentDataset(dsets) if multicomponent else dsets[0]
|
212 |
+
|
213 |
+
return data.build_dataloader(dset, args.batch_size, args.num_workers, shuffle=False)
|
214 |
+
|
215 |
+
|
216 |
+
def make_prediction_for_models(
|
217 |
+
args: Namespace, model_paths: Iterator[Path], multicomponent: bool, output_path: Path
|
218 |
+
):
|
219 |
+
model = load_model(model_paths[0], multicomponent)
|
220 |
+
output_columns = load_output_columns(model_paths[0])
|
221 |
+
bounded = any(
|
222 |
+
isinstance(model.criterion, LossFunctionRegistry[loss_function])
|
223 |
+
for loss_function in LossFunctionRegistry.keys()
|
224 |
+
if "bounded" in loss_function
|
225 |
+
)
|
226 |
+
format_kwargs = dict(
|
227 |
+
no_header_row=args.no_header_row,
|
228 |
+
smiles_cols=args.smiles_columns,
|
229 |
+
rxn_cols=args.reaction_columns,
|
230 |
+
ignore_cols=None,
|
231 |
+
splits_col=None,
|
232 |
+
weight_col=None,
|
233 |
+
bounded=bounded,
|
234 |
+
)
|
235 |
+
format_kwargs["target_cols"] = output_columns if args.evaluation_methods is not None else []
|
236 |
+
test_loader = prepare_data_loader(args, multicomponent, False, format_kwargs)
|
237 |
+
logger.info(f"test size: {len(test_loader.dataset)}")
|
238 |
+
if args.cal_path is not None:
|
239 |
+
format_kwargs["target_cols"] = output_columns
|
240 |
+
cal_loader = prepare_data_loader(args, multicomponent, True, format_kwargs)
|
241 |
+
logger.info(f"calibration size: {len(cal_loader.dataset)}")
|
242 |
+
|
243 |
+
uncertainty_estimator = Factory.build(
|
244 |
+
UncertaintyEstimatorRegistry[args.uncertainty_method],
|
245 |
+
ensemble_size=args.dropout_sampling_size,
|
246 |
+
dropout=args.uncertainty_dropout_p,
|
247 |
+
)
|
248 |
+
|
249 |
+
models = [load_model(model_path, multicomponent) for model_path in model_paths]
|
250 |
+
trainer = pl.Trainer(
|
251 |
+
logger=False, enable_progress_bar=True, accelerator=args.accelerator, devices=args.devices
|
252 |
+
)
|
253 |
+
test_individual_preds, test_individual_uncs = uncertainty_estimator(
|
254 |
+
test_loader, models, trainer
|
255 |
+
)
|
256 |
+
test_preds = torch.mean(test_individual_preds, dim=0)
|
257 |
+
if not isinstance(uncertainty_estimator, NoUncertaintyEstimator):
|
258 |
+
test_uncs = torch.mean(test_individual_uncs, dim=0)
|
259 |
+
else:
|
260 |
+
test_uncs = None
|
261 |
+
|
262 |
+
if args.calibration_method is not None:
|
263 |
+
uncertainty_calibrator = Factory.build(
|
264 |
+
UncertaintyCalibratorRegistry[args.calibration_method],
|
265 |
+
p=args.calibration_interval_percentile / 100,
|
266 |
+
alpha=args.conformal_alpha,
|
267 |
+
)
|
268 |
+
cal_targets = cal_loader.dataset.Y
|
269 |
+
cal_mask = torch.from_numpy(np.isfinite(cal_targets))
|
270 |
+
cal_targets = np.nan_to_num(cal_targets, nan=0.0)
|
271 |
+
cal_targets = torch.from_numpy(cal_targets)
|
272 |
+
cal_individual_preds, cal_individual_uncs = uncertainty_estimator(
|
273 |
+
cal_loader, models, trainer
|
274 |
+
)
|
275 |
+
cal_preds = torch.mean(cal_individual_preds, dim=0)
|
276 |
+
cal_uncs = torch.mean(cal_individual_uncs, dim=0)
|
277 |
+
if isinstance(uncertainty_calibrator, MVEWeightingCalibrator):
|
278 |
+
uncertainty_calibrator.fit(cal_preds, cal_individual_uncs, cal_targets, cal_mask)
|
279 |
+
test_uncs = uncertainty_calibrator.apply(cal_individual_uncs)
|
280 |
+
else:
|
281 |
+
if isinstance(uncertainty_calibrator, RegressionCalibrator):
|
282 |
+
uncertainty_calibrator.fit(cal_preds, cal_uncs, cal_targets, cal_mask)
|
283 |
+
else:
|
284 |
+
uncertainty_calibrator.fit(cal_uncs, cal_targets, cal_mask)
|
285 |
+
test_uncs = uncertainty_calibrator.apply(test_uncs)
|
286 |
+
for i in range(test_individual_uncs.shape[0]):
|
287 |
+
test_individual_uncs[i] = uncertainty_calibrator.apply(test_individual_uncs[i])
|
288 |
+
|
289 |
+
if args.evaluation_methods is not None:
|
290 |
+
uncertainty_evaluators = [
|
291 |
+
Factory.build(UncertaintyEvaluatorRegistry[method])
|
292 |
+
for method in args.evaluation_methods
|
293 |
+
]
|
294 |
+
logger.info("Uncertainty evaluation metric:")
|
295 |
+
for evaluator in uncertainty_evaluators:
|
296 |
+
test_targets = test_loader.dataset.Y
|
297 |
+
test_mask = torch.from_numpy(np.isfinite(test_targets))
|
298 |
+
test_targets = np.nan_to_num(test_targets, nan=0.0)
|
299 |
+
test_targets = torch.from_numpy(test_targets)
|
300 |
+
if isinstance(evaluator, RegressionEvaluator):
|
301 |
+
metric_value = evaluator.evaluate(test_preds, test_uncs, test_targets, test_mask)
|
302 |
+
else:
|
303 |
+
metric_value = evaluator.evaluate(test_uncs, test_targets, test_mask)
|
304 |
+
logger.info(f"{evaluator.alias}: {metric_value.tolist()}")
|
305 |
+
|
306 |
+
if args.uncertainty_method == "none" and (
|
307 |
+
isinstance(model.predictor, MveFFN) or isinstance(model.predictor, EvidentialFFN)
|
308 |
+
):
|
309 |
+
test_preds = test_preds[..., 0]
|
310 |
+
test_individual_preds = test_individual_preds[..., 0]
|
311 |
+
|
312 |
+
if output_columns is None:
|
313 |
+
output_columns = [
|
314 |
+
f"pred_{i}" for i in range(test_preds.shape[1])
|
315 |
+
] # TODO: need to improve this for cases like multi-task MVE and multi-task multiclass
|
316 |
+
|
317 |
+
save_predictions(args, model, output_columns, test_preds, test_uncs, output_path)
|
318 |
+
|
319 |
+
if len(model_paths) > 1:
|
320 |
+
save_individual_predictions(
|
321 |
+
args,
|
322 |
+
model,
|
323 |
+
model_paths,
|
324 |
+
output_columns,
|
+            test_individual_preds,
+            test_individual_uncs,
+            output_path,
+        )
+
+
+def save_predictions(args, model, output_columns, test_preds, test_uncs, output_path):
+    unc_columns = [f"{col}_unc" for col in output_columns]
+
+    if isinstance(model.predictor, MulticlassClassificationFFN):
+        output_columns = output_columns + [f"{col}_prob" for col in output_columns]
+        predicted_class_labels = test_preds.argmax(axis=-1)
+        formatted_probability_strings = np.apply_along_axis(
+            lambda x: ",".join(map(str, x)), 2, test_preds.numpy()
+        )
+        test_preds = np.concatenate(
+            (predicted_class_labels, formatted_probability_strings), axis=-1
+        )
+
+    df_test = pd.read_csv(
+        args.test_path, header=None if args.no_header_row else "infer", index_col=False
+    )
+    df_test[output_columns] = test_preds
+
+    if args.uncertainty_method not in ["none", "classification"]:
+        df_test[unc_columns] = np.round(test_uncs, 6)
+
+    if output_path.suffix == ".pkl":
+        df_test = df_test.reset_index(drop=True)
+        df_test.to_pickle(output_path)
+    else:
+        df_test.to_csv(output_path, index=False)
+    logger.info(f"Predictions saved to '{output_path}'")
+
+
+def save_individual_predictions(
+    args,
+    model,
+    model_paths,
+    output_columns,
+    test_individual_preds,
+    test_individual_uncs,
+    output_path,
+):
+    unc_columns = [
+        f"{col}_unc_model_{i}" for i in range(len(model_paths)) for col in output_columns
+    ]
+
+    if isinstance(model.predictor, MulticlassClassificationFFN):
+        output_columns = [
+            item
+            for i in range(len(model_paths))
+            for col in output_columns
+            for item in (f"{col}_model_{i}", f"{col}_prob_model_{i}")
+        ]
+
+        predicted_class_labels = test_individual_preds.argmax(axis=-1)
+        formatted_probability_strings = np.apply_along_axis(
+            lambda x: ",".join(map(str, x)), 3, test_individual_preds.numpy()
+        )
+        test_individual_preds = np.concatenate(
+            (predicted_class_labels, formatted_probability_strings), axis=-1
+        )
+    else:
+        output_columns = [
+            f"{col}_model_{i}" for i in range(len(model_paths)) for col in output_columns
+        ]
+
+    m, n, t = test_individual_preds.shape
+    test_individual_preds = np.transpose(test_individual_preds, (1, 0, 2)).reshape(n, m * t)
+    df_test = pd.read_csv(
+        args.test_path, header=None if args.no_header_row else "infer", index_col=False
+    )
+    df_test[output_columns] = test_individual_preds
+
+    if args.uncertainty_method not in ["none", "classification", "ensemble"]:
+        m, n, t = test_individual_uncs.shape
+        test_individual_uncs = np.transpose(test_individual_uncs, (1, 0, 2)).reshape(n, m * t)
+        df_test[unc_columns] = np.round(test_individual_uncs, 6)
+
+    output_path = output_path.parent / Path(
+        str(args.output.stem) + "_individual" + str(output_path.suffix)
+    )
+    if output_path.suffix == ".pkl":
+        df_test = df_test.reset_index(drop=True)
+        df_test.to_pickle(output_path)
+    else:
+        df_test.to_csv(output_path, index=False)
+    logger.info(f"Individual predictions saved to '{output_path}'")
+    for i, model_path in enumerate(model_paths):
+        logger.info(
+            f"Results from model path {model_path} are saved under the column name ending with 'model_{i}'"
+        )
+
+
+def main(args):
+    match (args.smiles_columns, args.reaction_columns):
+        case [None, None]:
+            n_components = 1
+        case [_, None]:
+            n_components = len(args.smiles_columns)
+        case [None, _]:
+            n_components = len(args.reaction_columns)
+        case _:
+            n_components = len(args.smiles_columns) + len(args.reaction_columns)
+
+    multicomponent = n_components > 1
+
+    model_paths = find_models(args.model_paths)
+
+    make_prediction_for_models(args, model_paths, multicomponent, output_path=args.output)
+
+
+if __name__ == "__main__":
+    parser = ArgumentParser()
+    parser = PredictSubcommand.add_args(parser)
+
+    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
+    args = parser.parse_args()
+    args = PredictSubcommand.func(args)
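A note for readers of `save_predictions`/`save_individual_predictions` above: for multiclass models, the `(n_datapoints, n_tasks, n_classes)` prediction array is flattened into one predicted-label column plus one comma-joined probability string per task before being written back into the test CSV. A minimal standalone sketch with a toy array (illustrative only, not part of the diff):

    import numpy as np

    # toy multiclass output: 2 molecules x 1 task x 3 classes
    test_preds = np.array([[[0.1, 0.7, 0.2]],
                           [[0.6, 0.3, 0.1]]])

    labels = test_preds.argmax(axis=-1)  # (2, 1): predicted class index per task
    probs = np.apply_along_axis(lambda x: ",".join(map(str, x)), 2, test_preds)  # (2, 1) strings
    flat = np.concatenate((labels, probs), axis=-1)  # (2, 2): label column + probability column
    print(flat)  # roughly: [['1' '0.1,0.7,0.2'], ['0' '0.6,0.3,0.1']]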
chemprop-updated/chemprop/cli/train.py
ADDED
@@ -0,0 +1,1340 @@
+from copy import deepcopy
+from io import StringIO
+import json
+import logging
+from pathlib import Path
+import sys
+from tempfile import TemporaryDirectory
+
+from configargparse import ArgumentError, ArgumentParser, Namespace
+from lightning import pytorch as pl
+from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
+from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
+from lightning.pytorch.strategies import DDPStrategy
+import numpy as np
+import pandas as pd
+from rich.console import Console
+from rich.table import Column, Table
+import torch
+import torch.nn as nn
+
+from chemprop.cli.common import (
+    add_common_args,
+    find_models,
+    process_common_args,
+    validate_common_args,
+)
+from chemprop.cli.conf import CHEMPROP_TRAIN_DIR, NOW
+from chemprop.cli.utils import (
+    LookupAction,
+    Subcommand,
+    build_data_from_files,
+    get_column_names,
+    make_dataset,
+    parse_indices,
+)
+from chemprop.cli.utils.args import uppercase
+from chemprop.data import (
+    MoleculeDataset,
+    MolGraphDataset,
+    MulticomponentDataset,
+    ReactionDatapoint,
+    SplitType,
+    build_dataloader,
+    make_split_indices,
+    split_data_by_indices,
+)
+from chemprop.data.datasets import _MolGraphDatasetMixin
+from chemprop.models import MPNN, MulticomponentMPNN, save_model
+from chemprop.nn import AggregationRegistry, LossFunctionRegistry, MetricRegistry, PredictorRegistry
+from chemprop.nn.message_passing import (
+    AtomMessagePassing,
+    BondMessagePassing,
+    MulticomponentMessagePassing,
+)
+from chemprop.nn.transforms import GraphTransform, ScaleTransform, UnscaleTransform
+from chemprop.nn.utils import Activation
+from chemprop.utils import Factory
+
+logger = logging.getLogger(__name__)
+
+
+_CV_REMOVAL_ERROR = (
+    "The -k/--num-folds argument was removed in v2.1.0 - use --num-replicates instead."
+)
+
+
+class TrainSubcommand(Subcommand):
+    COMMAND = "train"
+    HELP = "Train a chemprop model."
+    parser = None
+
+    @classmethod
+    def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
+        parser = add_common_args(parser)
+        parser = add_train_args(parser)
+        cls.parser = parser
+        return parser
+
+    @classmethod
+    def func(cls, args: Namespace):
+        args = process_common_args(args)
+        validate_common_args(args)
+        args = process_train_args(args)
+        validate_train_args(args)
+
+        args.output_dir.mkdir(exist_ok=True, parents=True)
+        config_path = args.output_dir / "config.toml"
+        save_config(cls.parser, args, config_path)
+        main(args)
+
+
+def add_train_args(parser: ArgumentParser) -> ArgumentParser:
+    parser.add_argument(
+        "--config-path",
+        type=Path,
+        is_config_file=True,
+        help="Path to a configuration file (command line arguments override values in the configuration file)",
+    )
+    parser.add_argument(
+        "-i",
+        "--data-path",
+        type=Path,
+        help="Path to an input CSV file containing SMILES and the associated target values",
+    )
+    parser.add_argument(
+        "-o",
+        "--output-dir",
+        "--save-dir",
+        type=Path,
+        help="Directory where training outputs will be saved (defaults to ``CURRENT_DIRECTORY/chemprop_training/STEM_OF_INPUT/TIME_STAMP``)",
+    )
+    parser.add_argument(
+        "--remove-checkpoints",
+        action="store_true",
+        help="Remove intermediate checkpoint files after training is complete.",
+    )
+
+    # TODO: Add in v2.1; see if we can tell lightning how often to log training loss
+    # parser.add_argument(
+    #     "--log-frequency",
+    #     type=int,
+    #     default=10,
+    #     help="The number of batches between each logging of the training loss.",
+    # )
+
+    transfer_args = parser.add_argument_group("transfer learning args")
+    transfer_args.add_argument(
+        "--checkpoint",
+        type=Path,
+        nargs="+",
+        help="Path to checkpoint(s) or model file(s) for loading and overwriting weights. Accepts a single pre-trained model checkpoint (.ckpt), a single model file (.pt), a directory containing such files, or a list of paths and directories. If a directory is provided, it will recursively search for and use all (.pt) files found.",
+    )
+    transfer_args.add_argument(
+        "--freeze-encoder",
+        action="store_true",
+        help="Freeze the message passing layer from the checkpoint model (specified by ``--checkpoint``).",
+    )
+    transfer_args.add_argument(
+        "--model-frzn",
+        help="Path to model checkpoint file to be loaded for overwriting and freezing weights. By default, all MPNN weights are frozen with this option.",
+    )
+    transfer_args.add_argument(
+        "--frzn-ffn-layers",
+        type=int,
+        default=0,
+        help="Freeze the first ``n`` layers of the FFN from the checkpoint model (specified by ``--checkpoint``). The message passing layer should also be frozen with ``--freeze-encoder``.",
+    )
+    # transfer_args.add_argument(
+    #     "--freeze-first-only",
+    #     action="store_true",
+    #     help="Determines whether or not to use checkpoint_frzn for just the first encoder. Default (False) is to use the checkpoint to freeze all encoders. (only relevant for number_of_molecules > 1, where checkpoint model has number_of_molecules = 1)",
+    # )
+
+    # TODO: Add in v2.1
+    # parser.add_argument(
+    #     "--resume-experiment",
+    #     action="store_true",
+    #     help="Whether to resume the experiment. Loads test results from any folds that have already been completed and skips training those folds.",
+    # )
+    # parser.add_argument(
+    #     "--config-path",
+    #     help="Path to a :code:`.json` file containing arguments. Any arguments present in the config file will override arguments specified via the command line or by the defaults.",
+    # )
+    parser.add_argument(
+        "--ensemble-size",
+        type=int,
+        default=1,
+        help="Number of models in ensemble for each splitting of data",
+    )
+
+    # TODO: Add in v2.2
+    # abt_args = parser.add_argument_group("atom/bond target args")
+    # abt_args.add_argument(
+    #     "--is-atom-bond-targets",
+    #     action="store_true",
+    #     help="Whether this is atomic/bond properties prediction.",
+    # )
+    # abt_args.add_argument(
+    #     "--no-adding-bond-types",
+    #     action="store_true",
+    #     help="Whether the bond types determined by RDKit molecules should be added to the output of bond targets. This option is intended to be used with :code:`is_atom_bond_targets`.",
+    # )
+    # abt_args.add_argument(
+    #     "--keeping-atom-map",
+    #     action="store_true",
+    #     help="Whether RDKit molecules keep the original atom mapping. This option is intended to be used when providing atom-mapped SMILES with :code:`is_atom_bond_targets`.",
+    # )
+    # abt_args.add_argument(
+    #     "--no-shared-atom-bond-ffn",
+    #     action="store_true",
+    #     help="Whether the FFN weights for atom and bond targets should be independent between tasks.",
+    # )
+    # abt_args.add_argument(
+    #     "--weights-ffn-num-layers",
+    #     type=int,
+    #     default=2,
+    #     help="Number of layers in FFN for determining weights used in constrained targets.",
+    # )
+
+    mp_args = parser.add_argument_group("message passing")
+    mp_args.add_argument(
+        "--message-hidden-dim", type=int, default=300, help="Hidden dimension of the messages"
+    )
+    mp_args.add_argument(
+        "--message-bias", action="store_true", help="Add bias to the message passing layers"
+    )
+    mp_args.add_argument("--depth", type=int, default=3, help="Number of message passing steps")
+    mp_args.add_argument(
+        "--undirected",
+        action="store_true",
+        help="Pass messages on undirected bonds/edges (always sum the two relevant bond vectors)",
+    )
+    mp_args.add_argument(
+        "--dropout",
+        type=float,
+        default=0.0,
+        help="Dropout probability in message passing/FFN layers",
+    )
+    mp_args.add_argument(
+        "--mpn-shared",
+        action="store_true",
+        help="Whether to use the same message passing neural network for all input molecules (only relevant if ``number_of_molecules`` > 1)",
+    )
+    mp_args.add_argument(
+        "--activation",
+        type=uppercase,
+        default="RELU",
+        choices=list(Activation.keys()),
+        help="Activation function in message passing/FFN layers",
+    )
+    mp_args.add_argument(
+        "--aggregation",
+        "--agg",
+        default="norm",
+        action=LookupAction(AggregationRegistry),
+        help="Aggregation mode to use when constructing the graph-level representation",
+    )
+    mp_args.add_argument(
+        "--aggregation-norm",
+        type=float,
+        default=100,
+        help="Normalization factor by which to divide summed up atomic features for ``norm`` aggregation",
+    )
+    mp_args.add_argument(
+        "--atom-messages", action="store_true", help="Pass messages on atoms rather than bonds."
+    )
+
+    # TODO: Add in v2.1
+    # mpsolv_args = parser.add_argument_group("message passing with solvent")
+    # mpsolv_args.add_argument(
+    #     "--reaction-solvent",
+    #     action="store_true",
+    #     help="Whether to adjust the MPNN layer to take as input a reaction and a molecule, and to encode them with separate MPNNs.",
+    # )
+    # mpsolv_args.add_argument(
+    #     "--bias-solvent",
+    #     action="store_true",
+    #     help="Whether to add bias to linear layers for solvent MPN if :code:`reaction_solvent` is True.",
+    # )
+    # mpsolv_args.add_argument(
+    #     "--hidden-size-solvent",
+    #     type=int,
+    #     default=300,
+    #     help="Dimensionality of hidden layers in solvent MPN if :code:`reaction_solvent` is True.",
+    # )
+    # mpsolv_args.add_argument(
+    #     "--depth-solvent",
+    #     type=int,
+    #     default=3,
+    #     help="Number of message passing steps for solvent if :code:`reaction_solvent` is True.",
+    # )
+
+    ffn_args = parser.add_argument_group("FFN args")
+    ffn_args.add_argument(
+        "--ffn-hidden-dim", type=int, default=300, help="Hidden dimension in the FFN top model"
+    )
+    ffn_args.add_argument(  # TODO: the default in v1 was 2 (see the weights_ffn_num_layers option). Do we really want the default to be 1 now?
+        "--ffn-num-layers", type=int, default=1, help="Number of layers in FFN top model"
+    )
+    # TODO: Decide if we want to implement this in v2
+    # ffn_args.add_argument(
+    #     "--features-only",
+    #     action="store_true",
+    #     help="Use only the additional features in an FFN, no graph network.",
+    # )
+
+    extra_mpnn_args = parser.add_argument_group("extra MPNN args")
+    extra_mpnn_args.add_argument(
+        "--batch-norm", action="store_true", help="Turn on batch normalization after aggregation"
+    )
+    extra_mpnn_args.add_argument(
+        "--multiclass-num-classes",
+        type=int,
+        default=3,
+        help="Number of classes when running multiclass classification",
+    )
+    # TODO: Add in v2.1
+    # extra_mpnn_args.add_argument(
+    #     "--spectral-activation",
+    #     default="exp",
+    #     choices=["softplus", "exp"],
+    #     help="Indicates which function to use in task_type spectra training to constrain outputs to be positive.",
+    # )
+
+    train_data_args = parser.add_argument_group("training input data args")
+    train_data_args.add_argument(
+        "-w",
+        "--weight-column",
+        help="Name of the column in the input CSV containing individual data weights",
+    )
+    train_data_args.add_argument(
+        "--target-columns",
+        nargs="+",
+        help="Name of the columns containing target values (by default, uses all columns except the SMILES column and the ``ignore_columns``)",
+    )
+    train_data_args.add_argument(
+        "--ignore-columns",
+        nargs="+",
+        help="Name of the columns to ignore when ``target_columns`` is not provided",
+    )
+    train_data_args.add_argument(
+        "--no-cache",
+        action="store_true",
+        help="Turn off caching the featurized ``MolGraph``s at the beginning of training",
+    )
+    train_data_args.add_argument(
+        "--splits-column",
+        help="Name of the column in the input CSV file containing 'train', 'val', or 'test' for each row.",
+    )
+    # TODO: Add in v2.1
+    # train_data_args.add_argument(
+    #     "--spectra-phase-mask-path",
+    #     help="Path to a file containing a phase mask array, used for excluding particular regions in spectra predictions.",
+    # )
+
+    train_args = parser.add_argument_group("training args")
+    train_args.add_argument(
+        "-t",
+        "--task-type",
+        default="regression",
+        action=LookupAction(PredictorRegistry),
+        help="Type of dataset (determines the default loss function used during training, defaults to ``regression``)",
+    )
+    train_args.add_argument(
+        "-l",
+        "--loss-function",
+        action=LookupAction(LossFunctionRegistry),
+        help="Loss function to use during training (will use the default loss function for the given task type if not specified)",
+    )
+    train_args.add_argument(
+        "--v-kl",
+        "--evidential-regularization",
+        type=float,
+        default=0.0,
+        help="Specify the regularization coefficient used in the evidential loss function. The default value recommended by Soleimany et al. (2021) is 0.2. However, the optimal value is dataset-dependent, so users should test different values to find the best one for their model.",
+    )
+
+    train_args.add_argument(
+        "--eps", type=float, default=1e-8, help="Evidential regularization epsilon"
+    )
+    train_args.add_argument(
+        "--alpha", type=float, default=0.1, help="Target error bounds for quantile interval loss"
+    )
+    # TODO: Add in v2.1
+    # train_args.add_argument(  # TODO: Is threshold the same thing as the spectra target floor? I'm not sure, but I combined them.
+    #     "-T",
+    #     "--threshold",
+    #     "--spectra-target-floor",
+    #     type=float,
+    #     default=1e-8,
+    #     help="Spectral threshold limit. v1 help string: Values in targets for dataset type spectra are replaced with this value, intended to be a small positive number used to enforce positive values.",
+    # )
+    train_args.add_argument(
+        "--metrics",
+        "--metric",
+        nargs="+",
+        action=LookupAction(MetricRegistry),
+        help="Specify the evaluation metrics. If unspecified, chemprop will use the following metrics for given dataset types: regression -> ``rmse``, classification -> ``roc``, multiclass -> ``ce`` ('cross entropy'), spectral -> ``sid``. If multiple metrics are provided, the 0-th one will be used for early stopping and checkpointing.",
+    )
+    train_args.add_argument(
+        "--tracking-metric",
+        default="val_loss",
+        help="The metric to track for early stopping and checkpointing. Defaults to the criterion used during training.",
+    )
+    train_args.add_argument(
+        "--show-individual-scores",
+        action="store_true",
+        help="Show all scores for individual targets, not just average, at the end.",
+    )
+    train_args.add_argument(
+        "--task-weights",
+        nargs="+",
+        type=float,
+        help="Weights to apply for whole tasks in the loss function",
+    )
+    train_args.add_argument(
+        "--warmup-epochs",
+        type=int,
+        default=2,
+        help="Number of epochs during which learning rate increases linearly from ``init_lr`` to ``max_lr`` (afterwards, learning rate decreases exponentially from ``max_lr`` to ``final_lr``)",
+    )
+
+    train_args.add_argument("--init-lr", type=float, default=1e-4, help="Initial learning rate.")
+    train_args.add_argument("--max-lr", type=float, default=1e-3, help="Maximum learning rate.")
+    train_args.add_argument("--final-lr", type=float, default=1e-4, help="Final learning rate.")
+    train_args.add_argument("--epochs", type=int, default=50, help="Number of epochs to train over")
+    train_args.add_argument(
+        "--patience",
+        type=int,
+        default=None,
+        help="Number of epochs to wait for improvement before early stopping",
+    )
+    train_args.add_argument(
+        "--grad-clip",
+        type=float,
+        help="Passed directly to the lightning trainer which controls grad clipping (see the ``Trainer()`` docstring for details)",
+    )
+    train_args.add_argument(
+        "--class-balance",
+        action="store_true",
+        help="Ensures each training batch contains an equal number of positive and negative samples.",
+    )
+
+    split_args = parser.add_argument_group("split args")
+    split_args.add_argument(
+        "--split",
+        "--split-type",
+        type=uppercase,
+        default="RANDOM",
+        choices=list(SplitType.keys()),
+        help="Method of splitting the data into train/val/test (case insensitive)",
+    )
+    split_args.add_argument(
+        "--split-sizes",
+        type=float,
+        nargs=3,
+        default=[0.8, 0.1, 0.1],
+        help="Split proportions for train/validation/test sets",
+    )
+    split_args.add_argument(
+        "--split-key-molecule",
+        type=int,
+        default=0,
+        help="Specify the index of the key molecule used for splitting when multiple molecules are present and a constrained split_type is used (e.g., ``scaffold_balanced`` or ``random_with_repeated_smiles``). Note that this index begins with zero for the first molecule.",
+    )
+    split_args.add_argument("--num-replicates", type=int, default=1, help="Number of replicates.")
+    split_args.add_argument("-k", "--num-folds", help=_CV_REMOVAL_ERROR)
+    split_args.add_argument(
+        "--save-smiles-splits",
+        action="store_true",
+        help="Whether to store the SMILES in each train/val/test split",
+    )
+    split_args.add_argument(
+        "--splits-file",
+        type=Path,
+        help="Path to a JSON file containing pre-defined splits for the input data, formatted as a list of dictionaries with keys ``train``, ``val``, and ``test`` and values as lists of indices or formatted strings (e.g. [0, 1, 2, 4] or '0-2,4')",
+    )
+    split_args.add_argument(
+        "--data-seed",
+        type=int,
+        default=0,
+        help="Specify the random seed to use when splitting data into train/val/test sets. When ``--num-replicates`` > 1, the first replicate uses this seed and all subsequent replicates add 1 to the seed (also used for shuffling data in ``build_dataloader`` when ``shuffle`` is True).",
+    )
+
+    parser.add_argument(
+        "--pytorch-seed",
+        type=int,
+        default=None,
+        help="Seed for PyTorch randomness (e.g., random initial weights)",
+    )
+
+    return parser
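+# NOTE (editor): a minimal usage sketch for the flags defined above, assuming the
+# `chemprop` console entry point installed by this package; the file names are
+# hypothetical placeholders:
+#
+#   chemprop train -i bbbp.csv -t classification --split-sizes 0.8 0.1 0.1 \
+#       --epochs 50 -o chemprop_training/bbbp
+#
+# Values given on the command line override those read from `--config-path`.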
+
+
+def process_train_args(args: Namespace) -> Namespace:
+    if args.output_dir is None:
+        args.output_dir = CHEMPROP_TRAIN_DIR / args.data_path.stem / NOW
+
+    return args
+
+
+def validate_train_args(args):
+    if args.config_path is None and args.data_path is None:
+        raise ArgumentError(argument=None, message="Data path must be provided for training.")
+
+    if args.num_folds is not None:  # i.e. user-specified
+        raise ArgumentError(argument=None, message=_CV_REMOVAL_ERROR)
+
+    if args.data_path.suffix not in [".csv"]:
+        raise ArgumentError(
+            argument=None, message=f"Input data must be a CSV file. Got {args.data_path}"
+        )
+
+    if args.epochs != -1 and args.epochs <= args.warmup_epochs:
+        raise ArgumentError(
+            argument=None,
+            message=f"The number of epochs should be higher than the number of epochs during warmup. Got {args.epochs} epochs and {args.warmup_epochs} warmup epochs",
+        )
+
+    # TODO: model_frzn is deprecated; remove it in v2.2
+    if args.checkpoint is not None and args.model_frzn is not None:
+        raise ArgumentError(
+            argument=None,
+            message="`--checkpoint` and `--model-frzn` cannot be used at the same time.",
+        )
+
+    if "--model-frzn" in sys.argv:
+        logger.warning(
+            "`--model-frzn` is deprecated and will be removed in v2.2. "
+            "Please use `--checkpoint` with `--freeze-encoder` instead."
+        )
+
+    if args.freeze_encoder and args.checkpoint is None:
+        raise ArgumentError(
+            argument=None,
+            message="`--freeze-encoder` can only be used when `--checkpoint` is used.",
+        )
+
+    if args.frzn_ffn_layers > 0:
+        if args.checkpoint is None and args.model_frzn is None:
+            raise ArgumentError(
+                argument=None,
+                message="`--frzn-ffn-layers` can only be used when `--checkpoint` or `--model-frzn` (deprecated in v2.1) is used.",
+            )
+        if args.checkpoint is not None and not args.freeze_encoder:
+            raise ArgumentError(
+                argument=None,
+                message="To freeze the first `n` layers of the FFN via `--frzn-ffn-layers`, the message passing layer must also be frozen with `--freeze-encoder`.",
+            )
+
+    if args.class_balance and args.task_type != "classification":
+        raise ArgumentError(
+            argument=None, message="Class balance is only applicable for classification tasks."
+        )
+
+    valid_tracking_metrics = (
+        args.metrics or [PredictorRegistry[args.task_type]._T_default_metric.alias]
+    ) + ["val_loss"]
+    if args.tracking_metric not in valid_tracking_metrics:
+        raise ArgumentError(
+            argument=None,
+            message=f"Tracking metric must be one of {','.join(valid_tracking_metrics)}. "
+            f"Got {args.tracking_metric}. Additional tracking metric options can be specified with "
+            "the `--metrics` flag.",
+        )
+
+    input_cols, target_cols = get_column_names(
+        args.data_path,
+        args.smiles_columns,
+        args.reaction_columns,
+        args.target_columns,
+        args.ignore_columns,
+        args.splits_column,
+        args.weight_column,
+        args.no_header_row,
+    )
+
+    args.input_columns = input_cols
+    args.target_columns = target_cols
+
+    return args
+
|
563 |
+
|
564 |
+
def normalize_inputs(train_dset, val_dset, args):
|
565 |
+
multicomponent = isinstance(train_dset, MulticomponentDataset)
|
566 |
+
num_components = train_dset.n_components if multicomponent else 1
|
567 |
+
|
568 |
+
X_d_transform = None
|
569 |
+
V_f_transforms = [nn.Identity()] * num_components
|
570 |
+
E_f_transforms = [nn.Identity()] * num_components
|
571 |
+
V_d_transforms = [None] * num_components
|
572 |
+
graph_transforms = []
|
573 |
+
|
574 |
+
d_xd = train_dset.d_xd
|
575 |
+
d_vf = train_dset.d_vf
|
576 |
+
d_ef = train_dset.d_ef
|
577 |
+
d_vd = train_dset.d_vd
|
578 |
+
|
579 |
+
if d_xd > 0 and not args.no_descriptor_scaling:
|
580 |
+
scaler = train_dset.normalize_inputs("X_d")
|
581 |
+
val_dset.normalize_inputs("X_d", scaler)
|
582 |
+
|
583 |
+
scaler = scaler if not isinstance(scaler, list) else scaler[0]
|
584 |
+
|
585 |
+
if scaler is not None:
|
586 |
+
logger.info(
|
587 |
+
f"Descriptors: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}"
|
588 |
+
)
|
589 |
+
X_d_transform = ScaleTransform.from_standard_scaler(scaler)
|
590 |
+
|
591 |
+
if d_vf > 0 and not args.no_atom_feature_scaling:
|
592 |
+
scaler = train_dset.normalize_inputs("V_f")
|
593 |
+
val_dset.normalize_inputs("V_f", scaler)
|
594 |
+
|
595 |
+
scalers = [scaler] if not isinstance(scaler, list) else scaler
|
596 |
+
|
597 |
+
for i, scaler in enumerate(scalers):
|
598 |
+
if scaler is None:
|
599 |
+
continue
|
600 |
+
|
601 |
+
logger.info(
|
602 |
+
f"Atom features for mol {i}: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}"
|
603 |
+
)
|
604 |
+
featurizer = (
|
605 |
+
train_dset.datasets[i].featurizer if multicomponent else train_dset.featurizer
|
606 |
+
)
|
607 |
+
V_f_transforms[i] = ScaleTransform.from_standard_scaler(
|
608 |
+
scaler, pad=featurizer.atom_fdim - featurizer.extra_atom_fdim
|
609 |
+
)
|
610 |
+
|
611 |
+
if d_ef > 0 and not args.no_bond_feature_scaling:
|
612 |
+
scaler = train_dset.normalize_inputs("E_f")
|
613 |
+
val_dset.normalize_inputs("E_f", scaler)
|
614 |
+
|
615 |
+
scalers = [scaler] if not isinstance(scaler, list) else scaler
|
616 |
+
|
617 |
+
for i, scaler in enumerate(scalers):
|
618 |
+
if scaler is None:
|
619 |
+
continue
|
620 |
+
|
621 |
+
logger.info(
|
622 |
+
f"Bond features for mol {i}: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}"
|
623 |
+
)
|
624 |
+
featurizer = (
|
625 |
+
train_dset.datasets[i].featurizer if multicomponent else train_dset.featurizer
|
626 |
+
)
|
627 |
+
E_f_transforms[i] = ScaleTransform.from_standard_scaler(
|
628 |
+
scaler, pad=featurizer.bond_fdim - featurizer.extra_bond_fdim
|
629 |
+
)
|
630 |
+
|
631 |
+
for V_f_transform, E_f_transform in zip(V_f_transforms, E_f_transforms):
|
632 |
+
graph_transforms.append(GraphTransform(V_f_transform, E_f_transform))
|
633 |
+
|
634 |
+
if d_vd > 0 and not args.no_atom_descriptor_scaling:
|
635 |
+
scaler = train_dset.normalize_inputs("V_d")
|
636 |
+
val_dset.normalize_inputs("V_d", scaler)
|
637 |
+
|
638 |
+
scalers = [scaler] if not isinstance(scaler, list) else scaler
|
639 |
+
|
640 |
+
for i, scaler in enumerate(scalers):
|
641 |
+
if scaler is None:
|
642 |
+
continue
|
643 |
+
|
644 |
+
logger.info(
|
645 |
+
f"Atom descriptors for mol {i}: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}"
|
646 |
+
)
|
647 |
+
V_d_transforms[i] = ScaleTransform.from_standard_scaler(scaler)
|
648 |
+
|
649 |
+
return X_d_transform, graph_transforms, V_d_transforms
|
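+# NOTE (editor): a short reading of the flow above, as the code implements it:
+# scalers are fit on the training set only and the same statistics are applied to
+# the validation set; the fitted means/scales are then baked into the model as
+# `ScaleTransform`s, so raw, unscaled features can be fed at inference time and the
+# model reproduces the training-time scaling internally.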
+
+
+def load_and_use_pretrained_model_scalers(model_path: Path, train_dset, val_dset) -> None:
+    if isinstance(train_dset, MulticomponentDataset):
+        _model = MulticomponentMPNN.load_from_file(model_path)
+        blocks = _model.message_passing.blocks
+        train_dsets = train_dset.datasets
+        val_dsets = val_dset.datasets
+    else:
+        _model = MPNN.load_from_file(model_path)
+        blocks = [_model.message_passing]
+        train_dsets = [train_dset]
+        val_dsets = [val_dset]
+
+    for i in range(len(blocks)):
+        if isinstance(_model.X_d_transform, ScaleTransform):
+            scaler = _model.X_d_transform.to_standard_scaler()
+            train_dsets[i].normalize_inputs("X_d", scaler)
+            val_dsets[i].normalize_inputs("X_d", scaler)
+
+        if isinstance(blocks[i].graph_transform, GraphTransform):
+            if isinstance(blocks[i].graph_transform.V_transform, ScaleTransform):
+                V_anti_pad = (
+                    train_dsets[i].featurizer.atom_fdim - train_dsets[i].featurizer.extra_atom_fdim
+                )
+                scaler = blocks[i].graph_transform.V_transform.to_standard_scaler(
+                    anti_pad=V_anti_pad
+                )
+                train_dsets[i].normalize_inputs("V_f", scaler)
+                val_dsets[i].normalize_inputs("V_f", scaler)
+            if isinstance(blocks[i].graph_transform.E_transform, ScaleTransform):
+                E_anti_pad = (
+                    train_dsets[i].featurizer.bond_fdim - train_dsets[i].featurizer.extra_bond_fdim
+                )
+                scaler = blocks[i].graph_transform.E_transform.to_standard_scaler(
+                    anti_pad=E_anti_pad
+                )
+                train_dsets[i].normalize_inputs("E_f", scaler)
+                val_dsets[i].normalize_inputs("E_f", scaler)
+
+        if isinstance(blocks[i].V_d_transform, ScaleTransform):
+            scaler = blocks[i].V_d_transform.to_standard_scaler()
+            train_dsets[i].normalize_inputs("V_d", scaler)
+            val_dsets[i].normalize_inputs("V_d", scaler)
+
+    if isinstance(_model.predictor.output_transform, UnscaleTransform):
+        scaler = _model.predictor.output_transform.to_standard_scaler()
+        train_dset.normalize_targets(scaler)
+        val_dset.normalize_targets(scaler)
+
+
+def save_config(parser: ArgumentParser, args: Namespace, config_path: Path):
+    config_args = deepcopy(args)
+    for key, value in vars(config_args).items():
+        if isinstance(value, Path):
+            setattr(config_args, key, str(value))
+
+    for key in ["atom_features_path", "atom_descriptors_path", "bond_features_path"]:
+        if getattr(config_args, key) is not None:
+            for index, path in getattr(config_args, key).items():
+                getattr(config_args, key)[index] = str(path)
+
+    parser.write_config_file(parsed_namespace=config_args, output_file_paths=[str(config_path)])
+
|
715 |
+
def save_smiles_splits(args: Namespace, output_dir, train_dset, val_dset, test_dset):
|
716 |
+
match (args.smiles_columns, args.reaction_columns):
|
717 |
+
case [_, None]:
|
718 |
+
column_labels = deepcopy(args.smiles_columns)
|
719 |
+
case [None, _]:
|
720 |
+
column_labels = deepcopy(args.reaction_columns)
|
721 |
+
case _:
|
722 |
+
column_labels = deepcopy(args.smiles_columns)
|
723 |
+
column_labels.extend(args.reaction_columns)
|
724 |
+
|
725 |
+
train_smis = train_dset.names
|
726 |
+
df_train = pd.DataFrame(train_smis, columns=column_labels)
|
727 |
+
df_train.to_csv(output_dir / "train_smiles.csv", index=False)
|
728 |
+
|
729 |
+
val_smis = val_dset.names
|
730 |
+
df_val = pd.DataFrame(val_smis, columns=column_labels)
|
731 |
+
df_val.to_csv(output_dir / "val_smiles.csv", index=False)
|
732 |
+
|
733 |
+
if test_dset is not None:
|
734 |
+
test_smis = test_dset.names
|
735 |
+
df_test = pd.DataFrame(test_smis, columns=column_labels)
|
736 |
+
df_test.to_csv(output_dir / "test_smiles.csv", index=False)
|
737 |
+
|
738 |
+
|
739 |
+
def build_splits(args, format_kwargs, featurization_kwargs):
|
740 |
+
"""build the train/val/test splits"""
|
741 |
+
logger.info(f"Pulling data from file: {args.data_path}")
|
742 |
+
all_data = build_data_from_files(
|
743 |
+
args.data_path,
|
744 |
+
p_descriptors=args.descriptors_path,
|
745 |
+
p_atom_feats=args.atom_features_path,
|
746 |
+
p_bond_feats=args.bond_features_path,
|
747 |
+
p_atom_descs=args.atom_descriptors_path,
|
748 |
+
**format_kwargs,
|
749 |
+
**featurization_kwargs,
|
750 |
+
)
|
751 |
+
|
752 |
+
if args.splits_column is not None:
|
753 |
+
df = pd.read_csv(
|
754 |
+
args.data_path, header=None if args.no_header_row else "infer", index_col=False
|
755 |
+
)
|
756 |
+
grouped = df.groupby(df[args.splits_column].str.lower())
|
757 |
+
train_indices = grouped.groups.get("train", pd.Index([])).tolist()
|
758 |
+
val_indices = grouped.groups.get("val", pd.Index([])).tolist()
|
759 |
+
test_indices = grouped.groups.get("test", pd.Index([])).tolist()
|
760 |
+
train_indices, val_indices, test_indices = [train_indices], [val_indices], [test_indices]
|
761 |
+
|
762 |
+
elif args.splits_file is not None:
|
763 |
+
with open(args.splits_file, "rb") as json_file:
|
764 |
+
split_idxss = json.load(json_file)
|
765 |
+
train_indices = [parse_indices(d["train"]) for d in split_idxss]
|
766 |
+
val_indices = [parse_indices(d["val"]) for d in split_idxss]
|
767 |
+
test_indices = [parse_indices(d["test"]) for d in split_idxss]
|
768 |
+
args.num_replicates = len(split_idxss)
|
769 |
+
|
770 |
+
else:
|
771 |
+
splitting_data = all_data[args.split_key_molecule]
|
772 |
+
if isinstance(splitting_data[0], ReactionDatapoint):
|
773 |
+
splitting_mols = [datapoint.rct for datapoint in splitting_data]
|
774 |
+
else:
|
775 |
+
splitting_mols = [datapoint.mol for datapoint in splitting_data]
|
776 |
+
train_indices, val_indices, test_indices = make_split_indices(
|
777 |
+
splitting_mols, args.split, args.split_sizes, args.data_seed, args.num_replicates
|
778 |
+
)
|
779 |
+
|
780 |
+
train_data, val_data, test_data = split_data_by_indices(
|
781 |
+
all_data, train_indices, val_indices, test_indices
|
782 |
+
)
|
783 |
+
for i_split in range(len(train_data)):
|
784 |
+
sizes = [len(train_data[i_split][0]), len(val_data[i_split][0]), len(test_data[i_split][0])]
|
785 |
+
logger.info(f"train/val/test split_{i_split} sizes: {sizes}")
|
786 |
+
|
787 |
+
return train_data, val_data, test_data
|
788 |
+
|
789 |
+
|
790 |
+
def summarize(
|
791 |
+
target_cols: list[str], task_type: str, dataset: _MolGraphDatasetMixin
|
792 |
+
) -> tuple[list, list]:
|
793 |
+
if task_type in [
|
794 |
+
"regression",
|
795 |
+
"regression-mve",
|
796 |
+
"regression-evidential",
|
797 |
+
"regression-quantile",
|
798 |
+
]:
|
799 |
+
if isinstance(dataset, MulticomponentDataset):
|
800 |
+
y = dataset.datasets[0].Y
|
801 |
+
else:
|
802 |
+
y = dataset.Y
|
803 |
+
y_mean = np.nanmean(y, axis=0)
|
804 |
+
y_std = np.nanstd(y, axis=0)
|
805 |
+
y_median = np.nanmedian(y, axis=0)
|
806 |
+
mean_dev_abs = np.abs(y - y_mean)
|
807 |
+
num_targets = np.sum(~np.isnan(y), axis=0)
|
808 |
+
frac_1_sigma = np.sum((mean_dev_abs < y_std), axis=0) / num_targets
|
809 |
+
frac_2_sigma = np.sum((mean_dev_abs < 2 * y_std), axis=0) / num_targets
|
810 |
+
|
811 |
+
column_headers = ["Statistic"] + [f"Value ({target_cols[i]})" for i in range(y.shape[1])]
|
812 |
+
table_rows = [
|
813 |
+
["Num. smiles"] + [f"{len(y)}" for i in range(y.shape[1])],
|
814 |
+
["Num. targets"] + [f"{num_targets[i]}" for i in range(y.shape[1])],
|
815 |
+
["Num. NaN"] + [f"{len(y) - num_targets[i]}" for i in range(y.shape[1])],
|
816 |
+
["Mean"] + [f"{mean:0.3g}" for mean in y_mean],
|
817 |
+
["Std. dev."] + [f"{std:0.3g}" for std in y_std],
|
818 |
+
["Median"] + [f"{median:0.3g}" for median in y_median],
|
819 |
+
["% within 1 s.d."] + [f"{sigma:0.0%}" for sigma in frac_1_sigma],
|
820 |
+
["% within 2 s.d."] + [f"{sigma:0.0%}" for sigma in frac_2_sigma],
|
821 |
+
]
|
822 |
+
return (column_headers, table_rows)
|
823 |
+
elif task_type in [
|
824 |
+
"classification",
|
825 |
+
"classification-dirichlet",
|
826 |
+
"multiclass",
|
827 |
+
"multiclass-dirichlet",
|
828 |
+
]:
|
829 |
+
if isinstance(dataset, MulticomponentDataset):
|
830 |
+
y = dataset.datasets[0].Y
|
831 |
+
else:
|
832 |
+
y = dataset.Y
|
833 |
+
|
834 |
+
mask = np.isnan(y)
|
835 |
+
classes = np.sort(np.unique(y[~mask]))
|
836 |
+
|
837 |
+
class_counts = np.stack([(classes[:, None] == y[:, i]).sum(1) for i in range(y.shape[1])])
|
838 |
+
class_fracs = class_counts / y.shape[0]
|
839 |
+
nan_count = np.nansum(mask, axis=0)
|
840 |
+
nan_frac = nan_count / y.shape[0]
|
841 |
+
|
842 |
+
column_headers = ["Class"] + [f"Count/Percent {target_cols[i]}" for i in range(y.shape[1])]
|
843 |
+
|
844 |
+
table_rows = [
|
845 |
+
[f"{k}"] + [f"{class_counts[j, i]}/{class_fracs[j, i]:0.0%}" for j in range(y.shape[1])]
|
846 |
+
for i, k in enumerate(classes)
|
847 |
+
]
|
848 |
+
|
849 |
+
nan_row = ["NaN"] + [f"{nan_count[i]}/{nan_frac[i]:0.0%}" for i in range(y.shape[1])]
|
850 |
+
table_rows.append(nan_row)
|
851 |
+
|
852 |
+
total_row = ["Total"] + [f"{y.shape[0]}/{100.00}%" for i in range(y.shape[1])]
|
853 |
+
table_rows.append(total_row)
|
854 |
+
|
855 |
+
return (column_headers, table_rows)
|
856 |
+
else:
|
857 |
+
raise ValueError(f"unsupported task type! Task type '{task_type}' was not recognized.")
|
858 |
+
|
859 |
+
|
860 |
+
def build_table(column_headers: list[str], table_rows: list[str], title: str | None = None) -> str:
|
861 |
+
right_justified_columns = [
|
862 |
+
Column(header=column_header, justify="right") for column_header in column_headers
|
863 |
+
]
|
864 |
+
table = Table(*right_justified_columns, title=title)
|
865 |
+
for row in table_rows:
|
866 |
+
table.add_row(*row)
|
867 |
+
|
868 |
+
console = Console(record=True, file=StringIO(), width=200)
|
869 |
+
console.print(table)
|
870 |
+
return console.export_text()
|
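+# NOTE (editor): `Console(record=True, file=StringIO())` renders the rich table into
+# an in-memory buffer rather than stdout, and `export_text()` returns the rendered
+# string so the caller can emit it through the standard `logging` machinery.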
+
+
+def build_datasets(args, train_data, val_data, test_data):
+    """build the train/val/test datasets, where :attr:`test_data` may be None"""
+    multicomponent = len(train_data) > 1
+    if multicomponent:
+        train_dsets = [
+            make_dataset(data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
+            for data in train_data
+        ]
+        val_dsets = [
+            make_dataset(data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
+            for data in val_data
+        ]
+        train_dset = MulticomponentDataset(train_dsets)
+        val_dset = MulticomponentDataset(val_dsets)
+        if len(test_data[0]) > 0:
+            test_dsets = [
+                make_dataset(data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
+                for data in test_data
+            ]
+            test_dset = MulticomponentDataset(test_dsets)
+        else:
+            test_dset = None
+    else:
+        train_data = train_data[0]
+        val_data = val_data[0]
+        test_data = test_data[0]
+        train_dset = make_dataset(train_data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
+        val_dset = make_dataset(val_data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
+        if len(test_data) > 0:
+            test_dset = make_dataset(test_data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
+        else:
+            test_dset = None
+    if args.task_type != "spectral":
+        for dataset, label in zip(
+            [train_dset, val_dset, test_dset], ["Training", "Validation", "Test"]
+        ):
+            column_headers, table_rows = summarize(args.target_columns, args.task_type, dataset)
+            output = build_table(column_headers, table_rows, f"Summary of {label} Data")
+            logger.info("\n" + output)
+
+    return train_dset, val_dset, test_dset
+
|
915 |
+
|
916 |
+
def build_model(
|
917 |
+
args,
|
918 |
+
train_dset: MolGraphDataset | MulticomponentDataset,
|
919 |
+
output_transform: UnscaleTransform,
|
920 |
+
input_transforms: tuple[ScaleTransform, list[GraphTransform], list[ScaleTransform]],
|
921 |
+
) -> MPNN:
|
922 |
+
mp_cls = AtomMessagePassing if args.atom_messages else BondMessagePassing
|
923 |
+
|
924 |
+
X_d_transform, graph_transforms, V_d_transforms = input_transforms
|
925 |
+
if isinstance(train_dset, MulticomponentDataset):
|
926 |
+
mp_blocks = [
|
927 |
+
mp_cls(
|
928 |
+
train_dset.datasets[i].featurizer.atom_fdim,
|
929 |
+
train_dset.datasets[i].featurizer.bond_fdim,
|
930 |
+
d_h=args.message_hidden_dim,
|
931 |
+
d_vd=(
|
932 |
+
train_dset.datasets[i].d_vd
|
933 |
+
if isinstance(train_dset.datasets[i], MoleculeDataset)
|
934 |
+
else 0
|
935 |
+
),
|
936 |
+
bias=args.message_bias,
|
937 |
+
depth=args.depth,
|
938 |
+
undirected=args.undirected,
|
939 |
+
dropout=args.dropout,
|
940 |
+
activation=args.activation,
|
941 |
+
V_d_transform=V_d_transforms[i],
|
942 |
+
graph_transform=graph_transforms[i],
|
943 |
+
)
|
944 |
+
for i in range(train_dset.n_components)
|
945 |
+
]
|
946 |
+
if args.mpn_shared:
|
947 |
+
if args.reaction_columns is not None and args.smiles_columns is not None:
|
948 |
+
raise ArgumentError(
|
949 |
+
argument=None,
|
950 |
+
message="Cannot use shared MPNN with both molecule and reaction data.",
|
951 |
+
)
|
952 |
+
|
953 |
+
mp_block = MulticomponentMessagePassing(mp_blocks, train_dset.n_components, args.mpn_shared)
|
954 |
+
# NOTE(degraff): this if/else block should be handled by the init of MulticomponentMessagePassing
|
955 |
+
# if args.mpn_shared:
|
956 |
+
# mp_block = MulticomponentMessagePassing(mp_blocks[0], n_components, args.mpn_shared)
|
957 |
+
# else:
|
958 |
+
d_xd = train_dset.datasets[0].d_xd
|
959 |
+
n_tasks = train_dset.datasets[0].Y.shape[1]
|
960 |
+
mpnn_cls = MulticomponentMPNN
|
961 |
+
else:
|
962 |
+
mp_block = mp_cls(
|
963 |
+
train_dset.featurizer.atom_fdim,
|
964 |
+
train_dset.featurizer.bond_fdim,
|
965 |
+
d_h=args.message_hidden_dim,
|
966 |
+
d_vd=train_dset.d_vd if isinstance(train_dset, MoleculeDataset) else 0,
|
967 |
+
bias=args.message_bias,
|
968 |
+
depth=args.depth,
|
969 |
+
undirected=args.undirected,
|
970 |
+
dropout=args.dropout,
|
971 |
+
activation=args.activation,
|
972 |
+
V_d_transform=V_d_transforms[0],
|
973 |
+
graph_transform=graph_transforms[0],
|
974 |
+
)
|
975 |
+
d_xd = train_dset.d_xd
|
976 |
+
n_tasks = train_dset.Y.shape[1]
|
977 |
+
mpnn_cls = MPNN
|
978 |
+
|
979 |
+
agg = Factory.build(AggregationRegistry[args.aggregation], norm=args.aggregation_norm)
|
980 |
+
predictor_cls = PredictorRegistry[args.task_type]
|
981 |
+
if args.loss_function is not None:
|
982 |
+
task_weights = torch.ones(n_tasks) if args.task_weights is None else args.task_weights
|
983 |
+
criterion = Factory.build(
|
984 |
+
LossFunctionRegistry[args.loss_function],
|
985 |
+
task_weights=task_weights,
|
986 |
+
v_kl=args.v_kl,
|
987 |
+
# threshold=args.threshold, TODO: Add in v2.1
|
988 |
+
eps=args.eps,
|
989 |
+
alpha=args.alpha,
|
990 |
+
)
|
991 |
+
else:
|
992 |
+
criterion = None
|
993 |
+
if args.metrics is not None:
|
994 |
+
metrics = [Factory.build(MetricRegistry[metric]) for metric in args.metrics]
|
995 |
+
else:
|
996 |
+
metrics = None
|
997 |
+
|
998 |
+
predictor = Factory.build(
|
999 |
+
predictor_cls,
|
1000 |
+
input_dim=mp_block.output_dim + d_xd,
|
1001 |
+
n_tasks=n_tasks,
|
1002 |
+
hidden_dim=args.ffn_hidden_dim,
|
1003 |
+
n_layers=args.ffn_num_layers,
|
1004 |
+
dropout=args.dropout,
|
1005 |
+
activation=args.activation,
|
1006 |
+
criterion=criterion,
|
1007 |
+
task_weights=args.task_weights,
|
1008 |
+
n_classes=args.multiclass_num_classes,
|
1009 |
+
output_transform=output_transform,
|
1010 |
+
# spectral_activation=args.spectral_activation, TODO: Add in v2.1
|
1011 |
+
)
|
1012 |
+
|
1013 |
+
if args.loss_function is None:
|
1014 |
+
logger.info(
|
1015 |
+
f"No loss function was specified! Using class default: {predictor_cls._T_default_criterion}"
|
1016 |
+
)
|
1017 |
+
|
1018 |
+
return mpnn_cls(
|
1019 |
+
mp_block,
|
1020 |
+
agg,
|
1021 |
+
predictor,
|
1022 |
+
args.batch_norm,
|
1023 |
+
metrics,
|
1024 |
+
args.warmup_epochs,
|
1025 |
+
args.init_lr,
|
1026 |
+
args.max_lr,
|
1027 |
+
args.final_lr,
|
1028 |
+
X_d_transform=X_d_transform,
|
1029 |
+
)
|
1030 |
+
|
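+# NOTE (editor): the same assembly can be done programmatically; a rough sketch using
+# only names imported in this module, relying on the featurizer's default dimensions
+# (the hyperparameter values shown are illustrative, not prescriptive):
+#
+#   mp = BondMessagePassing(d_h=300, depth=3)
+#   agg = Factory.build(AggregationRegistry["norm"], norm=100)
+#   predictor = Factory.build(PredictorRegistry["regression"], input_dim=mp.output_dim, n_tasks=1)
+#   model = MPNN(mp, agg, predictor)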
1031 |
+
|
1032 |
+
def train_model(
|
1033 |
+
args, train_loader, val_loader, test_loader, output_dir, output_transform, input_transforms
|
1034 |
+
):
|
1035 |
+
if args.checkpoint is not None:
|
1036 |
+
model_paths = find_models(args.checkpoint)
|
1037 |
+
if args.ensemble_size != len(model_paths):
|
1038 |
+
logger.warning(
|
1039 |
+
f"The number of models in ensemble for each splitting of data is set to {len(model_paths)}."
|
1040 |
+
)
|
1041 |
+
args.ensemble_size = len(model_paths)
|
1042 |
+
|
1043 |
+
for model_idx in range(args.ensemble_size):
|
1044 |
+
model_output_dir = output_dir / f"model_{model_idx}"
|
1045 |
+
model_output_dir.mkdir(exist_ok=True, parents=True)
|
1046 |
+
|
1047 |
+
if args.pytorch_seed is None:
|
1048 |
+
seed = torch.seed()
|
1049 |
+
deterministic = False
|
1050 |
+
else:
|
1051 |
+
seed = args.pytorch_seed + model_idx
|
1052 |
+
deterministic = True
|
1053 |
+
|
1054 |
+
torch.manual_seed(seed)
|
1055 |
+
|
1056 |
+
if args.checkpoint or args.model_frzn is not None:
|
1057 |
+
mpnn_cls = (
|
1058 |
+
MulticomponentMPNN
|
1059 |
+
if isinstance(train_loader.dataset, MulticomponentDataset)
|
1060 |
+
else MPNN
|
1061 |
+
)
|
1062 |
+
model_path = model_paths[model_idx] if args.checkpoint else args.model_frzn
|
1063 |
+
model = mpnn_cls.load_from_file(model_path)
|
1064 |
+
|
1065 |
+
if args.checkpoint:
|
1066 |
+
model.apply(
|
1067 |
+
lambda m: setattr(m, "p", args.dropout)
|
1068 |
+
if isinstance(m, torch.nn.Dropout)
|
1069 |
+
else None
|
1070 |
+
)
|
1071 |
+
|
1072 |
+
# TODO: model_frzn is deprecated and then remove in v2.2
|
1073 |
+
if args.model_frzn or args.freeze_encoder:
|
1074 |
+
model.message_passing.apply(lambda module: module.requires_grad_(False))
|
1075 |
+
model.message_passing.eval()
|
1076 |
+
model.bn.apply(lambda module: module.requires_grad_(False))
|
1077 |
+
model.bn.eval()
|
1078 |
+
for idx in range(args.frzn_ffn_layers):
|
1079 |
+
model.predictor.ffn[idx].requires_grad_(False)
|
1080 |
+
model.predictor.ffn[idx + 1].eval()
|
1081 |
+
else:
|
1082 |
+
model = build_model(args, train_loader.dataset, output_transform, input_transforms)
|
1083 |
+
logger.info(model)
|
1084 |
+
|
1085 |
+
try:
|
1086 |
+
trainer_logger = TensorBoardLogger(
|
1087 |
+
model_output_dir, "trainer_logs", default_hp_metric=False
|
1088 |
+
)
|
1089 |
+
except ModuleNotFoundError as e:
|
1090 |
+
logger.warning(
|
1091 |
+
f"Unable to import TensorBoardLogger, reverting to CSVLogger (original error: {e})."
|
1092 |
+
)
|
1093 |
+
trainer_logger = CSVLogger(model_output_dir, "trainer_logs")
|
1094 |
+
|
1095 |
+
if args.tracking_metric == "val_loss":
|
1096 |
+
T_tracking_metric = model.criterion.__class__
|
1097 |
+
tracking_metric = args.tracking_metric
|
1098 |
+
else:
|
1099 |
+
T_tracking_metric = MetricRegistry[args.tracking_metric]
|
1100 |
+
tracking_metric = "val/" + args.tracking_metric
|
1101 |
+
|
1102 |
+
monitor_mode = "max" if T_tracking_metric.higher_is_better else "min"
|
1103 |
+
logger.debug(f"Evaluation metric: '{T_tracking_metric.alias}', mode: '{monitor_mode}'")
|
1104 |
+
|
1105 |
+
if args.remove_checkpoints:
|
1106 |
+
temp_dir = TemporaryDirectory()
|
1107 |
+
checkpoint_dir = Path(temp_dir.name)
|
1108 |
+
else:
|
1109 |
+
checkpoint_dir = model_output_dir
|
1110 |
+
|
1111 |
+
checkpoint_filename = (
|
1112 |
+
f"best-epoch={{epoch}}-{tracking_metric.replace('/', '_')}="
|
1113 |
+
f"{{{tracking_metric}:.2f}}"
|
1114 |
+
)
|
1115 |
+
checkpointing = ModelCheckpoint(
|
1116 |
+
checkpoint_dir / "checkpoints",
|
1117 |
+
checkpoint_filename,
|
1118 |
+
tracking_metric,
|
1119 |
+
mode=monitor_mode,
|
1120 |
+
save_last=True,
|
1121 |
+
auto_insert_metric_name=False,
|
1122 |
+
)
|
1123 |
+
|
1124 |
+
if args.epochs != -1:
|
1125 |
+
patience = args.patience if args.patience is not None else args.epochs
|
1126 |
+
early_stopping = EarlyStopping(tracking_metric, patience=patience, mode=monitor_mode)
|
1127 |
+
callbacks = [checkpointing, early_stopping]
|
1128 |
+
else:
|
1129 |
+
callbacks = [checkpointing]
|
1130 |
+
|
1131 |
+
trainer = pl.Trainer(
|
1132 |
+
logger=trainer_logger,
|
1133 |
+
enable_progress_bar=True,
|
1134 |
+
accelerator=args.accelerator,
|
1135 |
+
devices=args.devices,
|
1136 |
+
max_epochs=args.epochs,
|
1137 |
+
callbacks=callbacks,
|
1138 |
+
gradient_clip_val=args.grad_clip,
|
1139 |
+
deterministic=deterministic,
|
1140 |
+
)
|
1141 |
+
trainer.fit(model, train_loader, val_loader)
|
1142 |
+
|
1143 |
+
if test_loader is not None:
|
1144 |
+
if isinstance(trainer.strategy, DDPStrategy):
|
1145 |
+
torch.distributed.destroy_process_group()
|
1146 |
+
|
1147 |
+
best_ckpt_path = trainer.checkpoint_callback.best_model_path
|
1148 |
+
trainer = pl.Trainer(
|
1149 |
+
logger=trainer_logger,
|
1150 |
+
enable_progress_bar=True,
|
1151 |
+
accelerator=args.accelerator,
|
1152 |
+
devices=1,
|
1153 |
+
)
|
1154 |
+
model = model.load_from_checkpoint(best_ckpt_path)
|
1155 |
+
predss = trainer.predict(model, dataloaders=test_loader)
|
1156 |
+
else:
|
1157 |
+
predss = trainer.predict(dataloaders=test_loader)
|
1158 |
+
|
1159 |
+
preds = torch.concat(predss, 0)
|
1160 |
+
if model.predictor.n_targets > 1:
|
1161 |
+
preds = preds[..., 0]
|
1162 |
+
preds = preds.numpy()
|
1163 |
+
|
1164 |
+
evaluate_and_save_predictions(
|
1165 |
+
preds, test_loader, model.metrics[:-1], model_output_dir, args
|
1166 |
+
)
|
1167 |
+
|
1168 |
+
best_model_path = checkpointing.best_model_path
|
1169 |
+
model = model.__class__.load_from_checkpoint(best_model_path)
|
1170 |
+
p_model = model_output_dir / "best.pt"
|
1171 |
+
save_model(p_model, model, args.target_columns)
|
1172 |
+
logger.info(f"Best model saved to '{p_model}'")
|
1173 |
+
|
1174 |
+
if args.remove_checkpoints:
|
1175 |
+
temp_dir.cleanup()
|
1176 |
+
|
1177 |
+
|
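The checkpoint filename above escapes its braces because Lightning, not Python, performs the substitution: with auto_insert_metric_name=False, ModelCheckpoint fills "{epoch}" and "{val/rmse:.2f}" from the logged metrics. A standalone sketch of the same pattern, with an illustrative metric name and no chemprop imports:

# Minimal sketch of the checkpointing/early-stopping pattern used above.
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint

tracking_metric = "val/rmse"  # assumed example metric name
checkpointing = ModelCheckpoint(
    dirpath="checkpoints",
    filename=f"best-epoch={{epoch}}-{tracking_metric.replace('/', '_')}={{{tracking_metric}:.2f}}",
    monitor=tracking_metric,
    mode="min",
    save_last=True,
    auto_insert_metric_name=False,  # let the braces in `filename` do the formatting
)
early_stopping = EarlyStopping(monitor=tracking_metric, patience=5, mode="min")
# produces files like checkpoints/best-epoch=12-val_rmse=0.87.ckpt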
def evaluate_and_save_predictions(preds, test_loader, metrics, model_output_dir, args):
    if isinstance(test_loader.dataset, MulticomponentDataset):
        test_dset = test_loader.dataset.datasets[0]
    else:
        test_dset = test_loader.dataset
    targets = test_dset.Y
    mask = torch.from_numpy(np.isfinite(targets))
    targets = np.nan_to_num(targets, nan=0.0)
    weights = torch.ones(len(test_dset))
    lt_mask = torch.from_numpy(test_dset.lt_mask) if test_dset.lt_mask[0] is not None else None
    gt_mask = torch.from_numpy(test_dset.gt_mask) if test_dset.gt_mask[0] is not None else None

    individual_scores = dict()
    for metric in metrics:
        individual_scores[metric.alias] = []
        for i, col in enumerate(args.target_columns):
            if "multiclass" in args.task_type:
                preds_slice = torch.from_numpy(preds[:, i : i + 1, :])
                targets_slice = torch.from_numpy(targets[:, i : i + 1])
            else:
                preds_slice = torch.from_numpy(preds[:, i : i + 1])
                targets_slice = torch.from_numpy(targets[:, i : i + 1])
            preds_loss = metric(
                preds_slice,
                targets_slice,
                mask[:, i : i + 1],
                weights,
                lt_mask[:, i] if lt_mask is not None else None,
                gt_mask[:, i] if gt_mask is not None else None,
            )
            individual_scores[metric.alias].append(preds_loss)

    logger.info("Test Set results:")
    for metric in metrics:
        avg_loss = sum(individual_scores[metric.alias]) / len(individual_scores[metric.alias])
        logger.info(f"test/{metric.alias}: {avg_loss}")

    if args.show_individual_scores:
        logger.info("Entire Test Set individual results:")
        for metric in metrics:
            for i, col in enumerate(args.target_columns):
                logger.info(f"test/{col}/{metric.alias}: {individual_scores[metric.alias][i]}")

    names = test_loader.dataset.names
    if isinstance(test_loader.dataset, MulticomponentDataset):
        namess = list(zip(*names))
    else:
        namess = [names]

    columns = args.input_columns + args.target_columns
    if "multiclass" in args.task_type:
        columns = columns + [f"{col}_prob" for col in args.target_columns]
        formatted_probability_strings = np.apply_along_axis(
            lambda x: ",".join(map(str, x)), 2, preds
        )
        predicted_class_labels = preds.argmax(axis=-1)
        df_preds = pd.DataFrame(
            list(zip(*namess, *predicted_class_labels.T, *formatted_probability_strings.T)),
            columns=columns,
        )
    else:
        df_preds = pd.DataFrame(list(zip(*namess, *preds.T)), columns=columns)
    df_preds.to_csv(model_output_dir / "test_predictions.csv", index=False)
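The per-target loop above evaluates each chemprop metric object on one target column at a time, skipping missing entries via the finite-value mask and then macro-averaging the column scores. A minimal NumPy sketch of the same masking idea, with RMSE as a stand-in metric (names here are illustrative, not chemprop API):

import numpy as np

preds = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
targets = np.array([[1.5, np.nan], [2.5, 4.5], [np.nan, 5.5]])

mask = np.isfinite(targets)        # ignore missing target entries
targets = np.nan_to_num(targets, nan=0.0)

rmses = []
for i in range(targets.shape[1]):  # one score per target column
    m = mask[:, i]
    rmses.append(np.sqrt(np.mean((preds[m, i] - targets[m, i]) ** 2)))

print({f"target_{i}": r for i, r in enumerate(rmses)})
print("macro average:", sum(rmses) / len(rmses))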
def main(args):
    format_kwargs = dict(
        no_header_row=args.no_header_row,
        smiles_cols=args.smiles_columns,
        rxn_cols=args.reaction_columns,
        target_cols=args.target_columns,
        ignore_cols=args.ignore_columns,
        splits_col=args.splits_column,
        weight_col=args.weight_column,
        bounded=args.loss_function is not None and "bounded" in args.loss_function,
    )

    featurization_kwargs = dict(
        molecule_featurizers=args.molecule_featurizers, keep_h=args.keep_h, add_h=args.add_h
    )

    splits = build_splits(args, format_kwargs, featurization_kwargs)

    for replicate_idx, (train_data, val_data, test_data) in enumerate(zip(*splits)):
        if args.num_replicates == 1:
            output_dir = args.output_dir
        else:
            output_dir = args.output_dir / f"replicate_{replicate_idx}"

        output_dir.mkdir(exist_ok=True, parents=True)

        train_dset, val_dset, test_dset = build_datasets(args, train_data, val_data, test_data)

        if args.save_smiles_splits:
            save_smiles_splits(args, output_dir, train_dset, val_dset, test_dset)

        if args.checkpoint or args.model_frzn is not None:
            model_paths = find_models(args.checkpoint)
            if len(model_paths) > 1:
                logger.warning(
                    "Multiple checkpoint files were loaded, but only the scalers from "
                    f"{model_paths[0]} are used. It is assumed that all models provided have the "
                    "same data scalings, meaning they were trained on the same data."
                )
            model_path = model_paths[0] if args.checkpoint else args.model_frzn
            load_and_use_pretrained_model_scalers(model_path, train_dset, val_dset)
            input_transforms = (None, None, None)
            output_transform = None
        else:
            input_transforms = normalize_inputs(train_dset, val_dset, args)

            if "regression" in args.task_type:
                output_scaler = train_dset.normalize_targets()
                val_dset.normalize_targets(output_scaler)
                logger.info(
                    f"Train data: mean = {output_scaler.mean_} | std = {output_scaler.scale_}"
                )
                output_transform = UnscaleTransform.from_standard_scaler(output_scaler)
            else:
                output_transform = None

        if not args.no_cache:
            train_dset.cache = True
            val_dset.cache = True

        train_loader = build_dataloader(
            train_dset,
            args.batch_size,
            args.num_workers,
            class_balance=args.class_balance,
            seed=args.data_seed,
        )
        if args.class_balance:
            logger.debug(
                f"With `--class-balance`, effective train size = {len(train_loader.sampler)}"
            )
        val_loader = build_dataloader(val_dset, args.batch_size, args.num_workers, shuffle=False)
        if test_dset is not None:
            test_loader = build_dataloader(
                test_dset, args.batch_size, args.num_workers, shuffle=False
            )
        else:
            test_loader = None

        train_model(
            args,
            train_loader,
            val_loader,
            test_loader,
            output_dir,
            output_transform,
            input_transforms,
        )


if __name__ == "__main__":
    # TODO: update this old code or remove it.
    parser = ArgumentParser()
    parser = TrainSubcommand.add_args(parser)

    logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
    args = parser.parse_args()
    TrainSubcommand.func(args)
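For orientation, a run with args.num_replicates = 2 leaves an output layout roughly like the sketch below (names taken from the code above; with a single replicate the files land in output_dir itself, and checkpoints are written to a temporary directory instead when args.remove_checkpoints is set):

output_dir/
├── replicate_0/
│   ├── checkpoints/           # best-epoch=...=....ckpt, last.ckpt
│   ├── trainer_logs/          # TensorBoard logs (or CSV fallback)
│   ├── best.pt                # written by save_model
│   └── test_predictions.csv   # only when a test split exists
└── replicate_1/
    └── ...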
chemprop-updated/chemprop/cli/utils/__init__.py
ADDED
@@ -0,0 +1,30 @@
from .actions import LookupAction
from .args import bounded
from .command import Subcommand
from .parsing import (
    build_data_from_files,
    get_column_names,
    make_datapoints,
    make_dataset,
    parse_indices,
)
from .utils import _pop_attr, _pop_attr_d, pop_attr

__all__ = [
    "bounded",
    "LookupAction",
    "Subcommand",
    "build_data_from_files",
    "make_datapoints",
    "make_dataset",
    "get_column_names",
    "parse_indices",
    "actions",
    "args",
    "command",
    "parsing",
    "utils",
    "pop_attr",
    "_pop_attr",
    "_pop_attr_d",
]
chemprop-updated/chemprop/cli/utils/actions.py
ADDED
@@ -0,0 +1,19 @@
from argparse import _StoreAction
from typing import Any, Mapping


def LookupAction(obj: Mapping[str, Any]):
    class LookupAction_(_StoreAction):
        def __init__(self, option_strings, dest, default=None, choices=None, **kwargs):
            if default not in obj.keys() and default is not None:
                raise ValueError(
                    f"Invalid value for arg 'default': '{default}'. "
                    f"Expected one of {tuple(obj.keys())}"
                )

            kwargs["choices"] = choices if choices is not None else obj.keys()
            kwargs["default"] = default

            super().__init__(option_strings, dest, **kwargs)

    return LookupAction_
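LookupAction is a factory: it closes over a registry mapping and returns an argparse Action class whose choices default to the registry's keys. A minimal usage sketch (the registry contents and the --metric flag are illustrative, not part of the diff):

from argparse import ArgumentParser

from chemprop.cli.utils.actions import LookupAction

METRIC_REGISTRY = {"rmse": object(), "mae": object()}  # stand-in registry

parser = ArgumentParser()
parser.add_argument("--metric", action=LookupAction(METRIC_REGISTRY), default="rmse")

args = parser.parse_args(["--metric", "mae"])
print(args.metric)                       # -> "mae"
parser.parse_args(["--metric", "auc"])   # exits: invalid choice 'auc'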
chemprop-updated/chemprop/cli/utils/args.py
ADDED
@@ -0,0 +1,34 @@
import functools

__all__ = ["bounded"]


def bounded(lo: float | None = None, hi: float | None = None):
    if lo is None and hi is None:
        raise ValueError("No bounds provided!")

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            x = f(*args, **kwargs)

            if (lo is not None and hi is not None) and not lo <= x <= hi:
                raise ValueError(f"Parsed value outside of range [{lo}, {hi}]! got: {x}")
            if hi is not None and x > hi:
                raise ValueError(f"Parsed value above {hi}! got: {x}")
            if lo is not None and x < lo:
                raise ValueError(f"Parsed value below {lo}! got: {x}")

            return x

        return wrapper

    return decorator


def uppercase(x: str):
    return x.upper()


def lowercase(x: str):
    return x.lower()
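The bounded decorator wraps any parsing callable so the parsed result is range-checked before it is returned, which makes it a natural fit for argparse type= arguments. A sketch, with an illustrative flag name:

from argparse import ArgumentParser

from chemprop.cli.utils.args import bounded

parser = ArgumentParser()
# float parses the string; the wrapper then enforces 0.0 <= value <= 1.0
parser.add_argument("--dropout", type=bounded(lo=0.0, hi=1.0)(float), default=0.0)

print(parser.parse_args(["--dropout", "0.2"]).dropout)  # -> 0.2
parser.parse_args(["--dropout", "1.5"])  # argparse error: value outside [0.0, 1.0]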
chemprop-updated/chemprop/cli/utils/command.py
ADDED
@@ -0,0 +1,24 @@
from abc import ABC, abstractmethod
from argparse import ArgumentParser, Namespace, _SubParsersAction


class Subcommand(ABC):
    COMMAND: str
    HELP: str | None = None

    @classmethod
    def add(cls, subparsers: _SubParsersAction, parents) -> ArgumentParser:
        parser = subparsers.add_parser(cls.COMMAND, help=cls.HELP, parents=parents)
        cls.add_args(parser).set_defaults(func=cls.func)

        return parser

    @classmethod
    @abstractmethod
    def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
        pass

    @classmethod
    @abstractmethod
    def func(cls, args: Namespace):
        pass
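A hypothetical subcommand showing the contract this base class imposes: add_args must return the parser it was given (so add() can chain set_defaults), and func receives the parsed Namespace. The "echo" command is made up for illustration:

from argparse import ArgumentParser, Namespace

from chemprop.cli.utils.command import Subcommand


class EchoSubcommand(Subcommand):
    COMMAND = "echo"
    HELP = "Print a message and exit."

    @classmethod
    def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
        parser.add_argument("message")
        return parser  # must return the parser

    @classmethod
    def func(cls, args: Namespace):
        print(args.message)


parser = ArgumentParser()
subparsers = parser.add_subparsers()
EchoSubcommand.add(subparsers, parents=[])
args = parser.parse_args(["echo", "hello"])
args.func(args)  # -> "hello"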
chemprop-updated/chemprop/cli/utils/parsing.py
ADDED
@@ -0,0 +1,446 @@
import logging
from os import PathLike
from typing import Literal, Mapping, Sequence

import numpy as np
import pandas as pd

from chemprop.data.datapoints import MoleculeDatapoint, ReactionDatapoint
from chemprop.data.datasets import MoleculeDataset, ReactionDataset
from chemprop.featurizers.atom import get_multi_hot_atom_featurizer
from chemprop.featurizers.bond import MultiHotBondFeaturizer, RIGRBondFeaturizer
from chemprop.featurizers.molecule import MoleculeFeaturizerRegistry
from chemprop.featurizers.molgraph import (
    CondensedGraphOfReactionFeaturizer,
    SimpleMoleculeMolGraphFeaturizer,
)
from chemprop.utils import make_mol

logger = logging.getLogger(__name__)


def parse_csv(
    path: PathLike,
    smiles_cols: Sequence[str] | None,
    rxn_cols: Sequence[str] | None,
    target_cols: Sequence[str] | None,
    ignore_cols: Sequence[str] | None,
    splits_col: str | None,
    weight_col: str | None,
    bounded: bool = False,
    no_header_row: bool = False,
):
    df = pd.read_csv(path, header=None if no_header_row else "infer", index_col=False)

    if smiles_cols is not None and rxn_cols is not None:
        smiss = df[smiles_cols].T.values.tolist()
        rxnss = df[rxn_cols].T.values.tolist()
        input_cols = [*smiles_cols, *rxn_cols]
    elif smiles_cols is not None and rxn_cols is None:
        smiss = df[smiles_cols].T.values.tolist()
        rxnss = None
        input_cols = smiles_cols
    elif smiles_cols is None and rxn_cols is not None:
        smiss = None
        rxnss = df[rxn_cols].T.values.tolist()
        input_cols = rxn_cols
    else:
        smiss = df.iloc[:, [0]].T.values.tolist()
        rxnss = None
        input_cols = [df.columns[0]]

    if target_cols is None:
        target_cols = list(
            column
            for column in df.columns
            if column
            not in set(  # if splits or weight is None, df.columns will never have None
                input_cols + (ignore_cols or []) + [splits_col] + [weight_col]
            )
        )

    Y = df[target_cols]
    weights = None if weight_col is None else df[weight_col].to_numpy(np.single)

    if bounded:
        lt_mask = Y.applymap(lambda x: "<" in x).to_numpy()
        gt_mask = Y.applymap(lambda x: ">" in x).to_numpy()
        Y = Y.applymap(lambda x: x.strip("<").strip(">")).to_numpy(np.single)
    else:
        Y = Y.to_numpy(np.single)
        lt_mask = None
        gt_mask = None

    return smiss, rxnss, Y, weights, lt_mask, gt_mask
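A sketch of the bounded-target convention parse_csv implements: "<" and ">" prefixes in target cells become less-than/greater-than masks and are stripped from Y. The file contents below are illustrative:

import io

import pandas as pd

csv = io.StringIO("smiles,logS\nCCO,-0.77\nc1ccccc1,<-2.0\nCC(=O)O,>0.1\n")
Y = pd.read_csv(csv)[["logS"]].astype(str)

lt_mask = Y.applymap(lambda x: "<" in x).to_numpy()  # [[False], [True], [False]]
gt_mask = Y.applymap(lambda x: ">" in x).to_numpy()  # [[False], [False], [True]]
values = Y.applymap(lambda x: x.strip("<").strip(">")).astype(float).to_numpy()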
def get_column_names(
    path: PathLike,
    smiles_cols: Sequence[str] | None,
    rxn_cols: Sequence[str] | None,
    target_cols: Sequence[str] | None,
    ignore_cols: Sequence[str] | None,
    splits_col: str | None,
    weight_col: str | None,
    no_header_row: bool = False,
) -> tuple[list[str], list[str]]:
    df_cols = pd.read_csv(path, index_col=False, nrows=0).columns.tolist()

    if no_header_row:
        return ["SMILES"], ["pred_" + str(i) for i in range(len(df_cols) - 1)]

    input_cols = (smiles_cols or []) + (rxn_cols or [])

    if len(input_cols) == 0:
        input_cols = [df_cols[0]]

    if target_cols is None:
        target_cols = list(
            column
            for column in df_cols
            if column
            not in set(  # a None splits/weight col in the set is harmless
                input_cols + (ignore_cols or []) + ([splits_col] or []) + ([weight_col] or [])
            )
        )

    return input_cols, target_cols
def make_datapoints(
    smiss: list[list[str]] | None,
    rxnss: list[list[str]] | None,
    Y: np.ndarray,
    weights: np.ndarray | None,
    lt_mask: np.ndarray | None,
    gt_mask: np.ndarray | None,
    X_d: np.ndarray | None,
    V_fss: list[list[np.ndarray] | list[None]] | None,
    E_fss: list[list[np.ndarray] | list[None]] | None,
    V_dss: list[list[np.ndarray] | list[None]] | None,
    molecule_featurizers: list[str] | None,
    keep_h: bool,
    add_h: bool,
) -> tuple[list[list[MoleculeDatapoint]], list[list[ReactionDatapoint]]]:
    """Make the :class:`MoleculeDatapoint`\s and :class:`ReactionDatapoint`\s for a given
    dataset.

    Parameters
    ----------
    smiss : list[list[str]] | None
        a list of ``j`` lists of ``n`` SMILES strings, where ``j`` is the number of molecules per
        datapoint and ``n`` is the number of datapoints. If ``None``, the corresponding list of
        :class:`MoleculeDatapoint`\s will be empty.
    rxnss : list[list[str]] | None
        a list of ``k`` lists of ``n`` reaction SMILES strings, where ``k`` is the number of
        reactions per datapoint. If ``None``, the corresponding list of
        :class:`ReactionDatapoint`\s will be empty.
    Y : np.ndarray
        the target values of shape ``n x m``, where ``m`` is the number of targets
    weights : np.ndarray | None
        the weights of the datapoints to use in the loss function of shape ``n x m``. If ``None``,
        the weights all default to 1.
    lt_mask : np.ndarray | None
        a boolean mask of shape ``n x m`` indicating whether the targets are less-than inequality
        targets. If ``None``, ``lt_mask`` for all datapoints will be ``None``.
    gt_mask : np.ndarray | None
        a boolean mask of shape ``n x m`` indicating whether the targets are greater-than
        inequality targets. If ``None``, ``gt_mask`` for all datapoints will be ``None``.
    X_d : np.ndarray | None
        the extra descriptors of shape ``n x p``, where ``p`` is the number of extra descriptors.
        If ``None``, ``x_d`` for all datapoints will be ``None``.
    V_fss : list[list[np.ndarray] | list[None]] | None
        a list of ``j`` lists of ``n`` np.ndarrays each of shape ``v_jn x q_j``, where ``v_jn`` is
        the number of atoms in the j-th molecule of the n-th datapoint and ``q_j`` is the number of
        extra atom features used for the j-th molecules. Any of the ``j`` lists can be a list of
        None values if the corresponding component does not use extra atom features. If ``None``,
        ``V_f`` for all datapoints will be ``None``.
    E_fss : list[list[np.ndarray] | list[None]] | None
        a list of ``j`` lists of ``n`` np.ndarrays each of shape ``e_jn x r_j``, where ``e_jn`` is
        the number of bonds in the j-th molecule of the n-th datapoint and ``r_j`` is the number of
        extra bond features used for the j-th molecules. Any of the ``j`` lists can be a list of
        None values if the corresponding component does not use extra bond features. If ``None``,
        ``E_f`` for all datapoints will be ``None``.
    V_dss : list[list[np.ndarray] | list[None]] | None
        a list of ``j`` lists of ``n`` np.ndarrays each of shape ``v_jn x s_j``, where ``s_j`` is
        the number of extra atom descriptors used for the j-th molecules. Any of the ``j`` lists
        can be a list of None values if the corresponding component does not use extra atom
        descriptors. If ``None``, ``V_d`` for all datapoints will be ``None``.
    molecule_featurizers : list[str] | None
        a list of molecule featurizer names to generate additional molecule features to use as
        extra descriptors. If there are multiple molecules per datapoint, the featurizers will be
        applied to each molecule and concatenated. Note that a :code:`ReactionDatapoint` has two
        RDKit :class:`~rdkit.Chem.Mol` objects, reactant(s) and product(s). Each
        ``molecule_featurizer`` will be applied to both of these objects.
    keep_h : bool
    add_h : bool

    Returns
    -------
    list[list[MoleculeDatapoint]]
        a list of ``j`` lists of ``n`` :class:`MoleculeDatapoint`\s
    list[list[ReactionDatapoint]]
        a list of ``k`` lists of ``n`` :class:`ReactionDatapoint`\s

    .. note::
        either ``j`` or ``k`` may be 0, in which case the corresponding list will be empty.

    Raises
    ------
    ValueError
        if both ``smiss`` and ``rxnss`` are ``None``.
        if ``smiss`` and ``rxnss`` are both given and have different lengths.
    """
    if smiss is None and rxnss is None:
        raise ValueError("args 'smiss' and 'rxnss' were both `None`!")
    elif rxnss is None:
        N = len(smiss[0])
        rxnss = []
    elif smiss is None:
        N = len(rxnss[0])
        smiss = []
    elif len(smiss[0]) != len(rxnss[0]):
        raise ValueError(
            f"args 'smiss' and 'rxnss' must have same length! got {len(smiss[0])} and {len(rxnss[0])}"
        )
    else:
        N = len(smiss[0])

    if len(smiss) > 0:
        molss = [[make_mol(smi, keep_h, add_h) for smi in smis] for smis in smiss]
    if len(rxnss) > 0:
        rctss = [
            [
                make_mol(f"{rct_smi}.{agt_smi}" if agt_smi else rct_smi, keep_h, add_h)
                for rct_smi, agt_smi, _ in (rxn.split(">") for rxn in rxns)
            ]
            for rxns in rxnss
        ]
        pdtss = [
            [make_mol(pdt_smi, keep_h, add_h) for _, _, pdt_smi in (rxn.split(">") for rxn in rxns)]
            for rxns in rxnss
        ]

    weights = np.ones(N, dtype=np.single) if weights is None else weights
    gt_mask = [None] * N if gt_mask is None else gt_mask
    lt_mask = [None] * N if lt_mask is None else lt_mask

    n_mols = len(smiss) if smiss else 0
    V_fss = [[None] * N] * n_mols if V_fss is None else V_fss
    E_fss = [[None] * N] * n_mols if E_fss is None else E_fss
    V_dss = [[None] * N] * n_mols if V_dss is None else V_dss

    if X_d is None and molecule_featurizers is None:
        X_d = [None] * N
    elif molecule_featurizers is None:
        pass
    else:
        molecule_featurizers = [MoleculeFeaturizerRegistry[mf]() for mf in molecule_featurizers]

        if len(smiss) > 0:
            mol_descriptors = np.hstack(
                [
                    np.vstack([np.hstack([mf(mol) for mf in molecule_featurizers]) for mol in mols])
                    for mols in molss
                ]
            )
            if X_d is None:
                X_d = mol_descriptors
            else:
                X_d = np.hstack([X_d, mol_descriptors])

        if len(rxnss) > 0:
            rct_pdt_descriptors = np.hstack(
                [
                    np.vstack(
                        [
                            np.hstack(
                                [mf(mol) for mf in molecule_featurizers for mol in (rct, pdt)]
                            )
                            for rct, pdt in zip(rcts, pdts)
                        ]
                    )
                    for rcts, pdts in zip(rctss, pdtss)
                ]
            )
            if X_d is None:
                X_d = rct_pdt_descriptors
            else:
                X_d = np.hstack([X_d, rct_pdt_descriptors])

    mol_data = [
        [
            MoleculeDatapoint(
                mol=molss[mol_idx][i],
                name=smis[i],
                y=Y[i],
                weight=weights[i],
                gt_mask=gt_mask[i],
                lt_mask=lt_mask[i],
                x_d=X_d[i],
                x_phase=None,
                V_f=V_fss[mol_idx][i],
                E_f=E_fss[mol_idx][i],
                V_d=V_dss[mol_idx][i],
            )
            for i in range(N)
        ]
        for mol_idx, smis in enumerate(smiss)
    ]
    rxn_data = [
        [
            ReactionDatapoint(
                rct=rctss[rxn_idx][i],
                pdt=pdtss[rxn_idx][i],
                name=rxns[i],
                y=Y[i],
                weight=weights[i],
                gt_mask=gt_mask[i],
                lt_mask=lt_mask[i],
                x_d=X_d[i],
                x_phase=None,
            )
            for i in range(N)
        ]
        for rxn_idx, rxns in enumerate(rxnss)
    ]

    return mol_data, rxn_data
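A minimal call sketch (assumes chemprop v2 is installed): two single-molecule datapoints with one regression target each and no extra features, so j = 1, n = 2, and the reaction list comes back empty.

import numpy as np

from chemprop.cli.utils.parsing import make_datapoints

smiss = [["CCO", "c1ccccc1"]]  # j=1 molecule per datapoint, n=2 datapoints
Y = np.array([[-0.77], [-2.0]], dtype=np.single)

mol_data, rxn_data = make_datapoints(
    smiss, None, Y, None, None, None,
    X_d=None, V_fss=None, E_fss=None, V_dss=None,
    molecule_featurizers=None, keep_h=False, add_h=False,
)
assert len(mol_data) == 1 and len(mol_data[0]) == 2 and rxn_data == []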
def build_data_from_files(
    p_data: PathLike,
    no_header_row: bool,
    smiles_cols: Sequence[str] | None,
    rxn_cols: Sequence[str] | None,
    target_cols: Sequence[str] | None,
    ignore_cols: Sequence[str] | None,
    splits_col: str | None,
    weight_col: str | None,
    bounded: bool,
    p_descriptors: PathLike,
    p_atom_feats: dict[int, PathLike],
    p_bond_feats: dict[int, PathLike],
    p_atom_descs: dict[int, PathLike],
    **featurization_kwargs: Mapping,
) -> list[list[MoleculeDatapoint] | list[ReactionDatapoint]]:
    smiss, rxnss, Y, weights, lt_mask, gt_mask = parse_csv(
        p_data,
        smiles_cols,
        rxn_cols,
        target_cols,
        ignore_cols,
        splits_col,
        weight_col,
        bounded,
        no_header_row,
    )
    n_molecules = len(smiss) if smiss is not None else 0
    n_datapoints = len(Y)

    X_ds = load_input_feats_and_descs(p_descriptors, None, None, feat_desc="X_d")
    V_fss = load_input_feats_and_descs(p_atom_feats, n_molecules, n_datapoints, feat_desc="V_f")
    E_fss = load_input_feats_and_descs(p_bond_feats, n_molecules, n_datapoints, feat_desc="E_f")
    V_dss = load_input_feats_and_descs(p_atom_descs, n_molecules, n_datapoints, feat_desc="V_d")

    mol_data, rxn_data = make_datapoints(
        smiss,
        rxnss,
        Y,
        weights,
        lt_mask,
        gt_mask,
        X_ds,
        V_fss,
        E_fss,
        V_dss,
        **featurization_kwargs,
    )

    return mol_data + rxn_data
def load_input_feats_and_descs(
    paths: dict[int, PathLike] | PathLike,
    n_molecules: int | None,
    n_datapoints: int | None,
    feat_desc: str,
):
    if paths is None:
        return None

    match feat_desc:
        case "X_d":
            path = paths
            loaded_feature = np.load(path)
            features = loaded_feature["arr_0"]

        case _:
            for index in paths:
                if index >= n_molecules:
                    raise ValueError(
                        f"For {n_molecules} molecules, atom/bond features/descriptors can only be "
                        f"specified for indices 0-{n_molecules - 1}! Got index {index}."
                    )

            features = []
            for idx in range(n_molecules):
                path = paths.get(idx, None)

                if path is not None:
                    loaded_feature = np.load(path)
                    loaded_feature = [
                        loaded_feature[f"arr_{i}"] for i in range(len(loaded_feature))
                    ]
                else:
                    loaded_feature = [None] * n_datapoints

                features.append(loaded_feature)
    return features
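The loader above expects NumPy .npz archives keyed "arr_0", "arr_1", ..., which is exactly what positional np.savez produces. A sketch for the per-atom features of two datapoints (the file name is illustrative):

import numpy as np

V_f_dp0 = np.random.rand(9, 4)  # 9 atoms x 4 extra features, datapoint 0
V_f_dp1 = np.random.rand(5, 4)  # 5 atoms x 4 extra features, datapoint 1
np.savez("atom_features_0.npz", V_f_dp0, V_f_dp1)  # keys: arr_0, arr_1

archive = np.load("atom_features_0.npz")
arrays = [archive[f"arr_{i}"] for i in range(len(archive))]  # same as the loader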
def make_dataset(
    data: Sequence[MoleculeDatapoint] | Sequence[ReactionDatapoint],
    reaction_mode: str,
    multi_hot_atom_featurizer_mode: Literal["V1", "V2", "ORGANIC", "RIGR"] = "V2",
) -> MoleculeDataset | ReactionDataset:
    atom_featurizer = get_multi_hot_atom_featurizer(multi_hot_atom_featurizer_mode)
    match multi_hot_atom_featurizer_mode:
        case "RIGR":
            bond_featurizer = RIGRBondFeaturizer()
        case "V1" | "V2" | "ORGANIC":
            bond_featurizer = MultiHotBondFeaturizer()
        case _:
            raise TypeError(
                f"Unsupported atom featurizer mode '{multi_hot_atom_featurizer_mode=}'!"
            )

    if isinstance(data[0], MoleculeDatapoint):
        extra_atom_fdim = data[0].V_f.shape[1] if data[0].V_f is not None else 0
        extra_bond_fdim = data[0].E_f.shape[1] if data[0].E_f is not None else 0
        featurizer = SimpleMoleculeMolGraphFeaturizer(
            atom_featurizer=atom_featurizer,
            bond_featurizer=bond_featurizer,
            extra_atom_fdim=extra_atom_fdim,
            extra_bond_fdim=extra_bond_fdim,
        )
        return MoleculeDataset(data, featurizer)

    featurizer = CondensedGraphOfReactionFeaturizer(
        mode_=reaction_mode, atom_featurizer=atom_featurizer
    )

    return ReactionDataset(data, featurizer)
def parse_indices(idxs):
    """Parses a string of indices into a list of integers. e.g. '0,1,2-4' -> [0, 1, 2, 3, 4]"""
    if isinstance(idxs, str):
        indices = []
        for idx in idxs.split(","):
            if "-" in idx:
                start, end = map(int, idx.split("-"))
                indices.extend(range(start, end + 1))
            else:
                indices.append(int(idx))
        return indices
    return idxs
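A quick check of the behavior documented above; non-string input passes through unchanged:

from chemprop.cli.utils.parsing import parse_indices

assert parse_indices("0,1,2-4") == [0, 1, 2, 3, 4]
assert parse_indices("7") == [7]
assert parse_indices([3, 5]) == [3, 5]  # already parsed: returned as-is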
chemprop-updated/chemprop/cli/utils/utils.py
ADDED
@@ -0,0 +1,31 @@
from typing import Any

__all__ = ["pop_attr"]


def pop_attr(o: object, attr: str, *args) -> Any | None:
    """like ``pop()`` but for attribute maps"""
    match len(args):
        case 0:
            return _pop_attr(o, attr)
        case 1:
            return _pop_attr_d(o, attr, args[0])
        case _:
            raise TypeError(f"Expected at most 2 arguments! got: {len(args)}")


def _pop_attr(o: object, attr: str) -> Any:
    val = getattr(o, attr)
    delattr(o, attr)

    return val


def _pop_attr_d(o: object, attr: str, default: Any | None = None) -> Any | None:
    try:
        val = getattr(o, attr)
        delattr(o, attr)
    except AttributeError:
        val = default

    return val
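pop_attr mirrors dict.pop for object attributes: read, delete, and return, with an optional default when the attribute is missing. A usage sketch:

from types import SimpleNamespace

from chemprop.cli.utils.utils import pop_attr

ns = SimpleNamespace(alpha=1)
assert pop_attr(ns, "alpha") == 1       # attribute removed from ns
assert not hasattr(ns, "alpha")
assert pop_attr(ns, "alpha", 42) == 42  # default when missing
# pop_attr(ns, "alpha")                 # would now raise AttributeError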
chemprop-updated/chemprop/conf.py
ADDED
@@ -0,0 +1,6 @@
"""Global configuration variables for chemprop"""

from chemprop.featurizers.molgraph.molecule import SimpleMoleculeMolGraphFeaturizer

DEFAULT_ATOM_FDIM, DEFAULT_BOND_FDIM = SimpleMoleculeMolGraphFeaturizer().shape
DEFAULT_HIDDEN_DIM = 300