hbhzm committed (verified)
Commit: 3ea26d1 · Parent(s): 8037081

Upload 625 files


add model and bbbp model weight

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set of 625 files.
Files changed (50)
  1. .gitattributes +2 -0
  2. chemprop-updated/.bumpversion.cfg +10 -0
  3. chemprop-updated/.dockerignore +3 -0
  4. chemprop-updated/.flake8 +9 -0
  5. chemprop-updated/.github/ISSUE_TEMPLATE/todo.md +11 -0
  6. chemprop-updated/.github/ISSUE_TEMPLATE/v1_bug_report.md +35 -0
  7. chemprop-updated/.github/ISSUE_TEMPLATE/v1_question.md +17 -0
  8. chemprop-updated/.github/ISSUE_TEMPLATE/v2_bug_report.md +35 -0
  9. chemprop-updated/.github/ISSUE_TEMPLATE/v2_feature_request.md +23 -0
  10. chemprop-updated/.github/ISSUE_TEMPLATE/v2_question.md +17 -0
  11. chemprop-updated/.github/PULL_REQUEST_TEMPLATE.md +18 -0
  12. chemprop-updated/.github/PULL_REQUEST_TEMPLATE/bugfix.md +12 -0
  13. chemprop-updated/.github/PULL_REQUEST_TEMPLATE/new_feature.md +15 -0
  14. chemprop-updated/.github/workflows/ci.yml +158 -0
  15. chemprop-updated/.gitignore +178 -0
  16. chemprop-updated/.readthedocs.yml +19 -0
  17. chemprop-updated/CITATIONS.bib +37 -0
  18. chemprop-updated/CONTRIBUTING.md +40 -0
  19. chemprop-updated/Dockerfile +50 -0
  20. chemprop-updated/LICENSE.txt +27 -0
  21. chemprop-updated/README.md +63 -0
  22. chemprop-updated/chemprop/__init__.py +5 -0
  23. chemprop-updated/chemprop/__pycache__/__init__.cpython-37.pyc +0 -0
  24. chemprop-updated/chemprop/__pycache__/args.cpython-37.pyc +0 -0
  25. chemprop-updated/chemprop/__pycache__/constants.cpython-37.pyc +0 -0
  26. chemprop-updated/chemprop/__pycache__/hyperopt_utils.cpython-37.pyc +0 -0
  27. chemprop-updated/chemprop/__pycache__/hyperparameter_optimization.cpython-37.pyc +0 -0
  28. chemprop-updated/chemprop/__pycache__/interpret.cpython-37.pyc +0 -0
  29. chemprop-updated/chemprop/__pycache__/multitask_utils.cpython-37.pyc +0 -0
  30. chemprop-updated/chemprop/__pycache__/nn_utils.cpython-37.pyc +0 -0
  31. chemprop-updated/chemprop/__pycache__/rdkit.cpython-37.pyc +0 -0
  32. chemprop-updated/chemprop/__pycache__/sklearn_predict.cpython-37.pyc +0 -0
  33. chemprop-updated/chemprop/__pycache__/sklearn_train.cpython-37.pyc +0 -0
  34. chemprop-updated/chemprop/__pycache__/spectra_utils.cpython-37.pyc +0 -0
  35. chemprop-updated/chemprop/__pycache__/utils.cpython-37.pyc +0 -0
  36. chemprop-updated/chemprop/cli/common.py +211 -0
  37. chemprop-updated/chemprop/cli/conf.py +9 -0
  38. chemprop-updated/chemprop/cli/convert.py +55 -0
  39. chemprop-updated/chemprop/cli/fingerprint.py +182 -0
  40. chemprop-updated/chemprop/cli/hpopt.py +537 -0
  41. chemprop-updated/chemprop/cli/main.py +85 -0
  42. chemprop-updated/chemprop/cli/predict.py +444 -0
  43. chemprop-updated/chemprop/cli/train.py +1340 -0
  44. chemprop-updated/chemprop/cli/utils/__init__.py +30 -0
  45. chemprop-updated/chemprop/cli/utils/actions.py +19 -0
  46. chemprop-updated/chemprop/cli/utils/args.py +34 -0
  47. chemprop-updated/chemprop/cli/utils/command.py +24 -0
  48. chemprop-updated/chemprop/cli/utils/parsing.py +446 -0
  49. chemprop-updated/chemprop/cli/utils/utils.py +31 -0
  50. chemprop-updated/chemprop/conf.py +6 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ chemprop-updated/docs/source/_static/images/message_passing.png filter=lfs diff=lfs merge=lfs -text
+ chemprop/docs/source/_static/images/message_passing.png filter=lfs diff=lfs merge=lfs -text
chemprop-updated/.bumpversion.cfg ADDED
@@ -0,0 +1,10 @@
+ [bumpversion]
+ current_version = 2.1.2
+ commit = True
+ tag = True
+
+ [bumpversion:file:pyproject.toml]
+
+ [bumpversion:file:chemprop/__init__.py]
+
+ [bumpversion:file:docs/source/conf.py]
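For context, the `bumpversion` tool reads this config and rewrites the tracked version string in each listed file, then (per `commit = True` / `tag = True`) commits and tags the result. A rough sketch of that substitution in Python — the file list and current version come from the config above; the new version and exact replacement semantics are illustrative assumptions, not the real tool:

```python
# Hypothetical sketch of what a `bumpversion patch` run would do with the
# config above. The file list and "2.1.2" come from the config; the new
# version string and single-pass replacement are illustrative assumptions.
from pathlib import Path

CURRENT, NEW = "2.1.2", "2.1.3"  # assumed patch bump
FILES = ["pyproject.toml", "chemprop/__init__.py", "docs/source/conf.py"]

for name in FILES:
    path = Path(name)
    path.write_text(path.read_text().replace(CURRENT, NEW))
    # with commit = True / tag = True, bumpversion would then commit the
    # change and create a git tag for the new version
```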
chemprop-updated/.dockerignore ADDED
@@ -0,0 +1,3 @@
+ **.git*
+ .dockerignore
+ Dockerfile
chemprop-updated/.flake8 ADDED
@@ -0,0 +1,9 @@
+ [flake8]
+ ignore = E203, E266, E501, F403, E741, W503, W605
+ max-line-length = 100
+ max-complexity = 18
+ per-file-ignores =
+     __init__.py: F401
+     chemprop/nn/predictors.py: F405
+     chemprop/nn/metrics.py: F405
+     tests/unit/nn/test_metrics.py: E121, E122, E131, E241, W291
chemprop-updated/.github/ISSUE_TEMPLATE/todo.md ADDED
@@ -0,0 +1,11 @@
+ ---
+ name: to-do
+ about: Add an item to the to-do list. More generic than a feature request
+ title: "[TODO]: "
+ labels: todo
+ assignees: ''
+
+ ---
+
+ **Notes**
+ _These could be implementation or more specific details to keep in mind, if they'll be helpful for issue tracking_
chemprop-updated/.github/ISSUE_TEMPLATE/v1_bug_report.md ADDED
@@ -0,0 +1,35 @@
+ ---
+ name: v1 Bug Report
+ about: Report a bug in v1 (will not be fixed)
+ title: "[v1 BUG]: "
+ labels: bug, v1-wontfix
+ assignees: ''
+
+ ---
+
+ **Describe the bug**
+ A clear and concise description of what the bug is.
+
+ **Example(s)**
+ Provide some examples of where the current code fails. Feel free to share your actual code for additional context, but a minimal and isolated example is preferred.
+
+ **Expected behavior**
+ A clear and concise description of what you expected to happen. If there is correct, expected output, include that here as well.
+
+ **Error Stack Trace**
+ If the bug is resulting in an error message, provide the _full_ stack trace (not just the last line). This is helpful for debugging, especially in cases where you aren't able to provide a minimal/isolated working example with accompanying files.
+
+ **Screenshots**
+ If applicable, add screenshots to help explain your problem.
+
+ **Environment**
+ - python version
+ - package versions: `conda list` or `pip list`
+ - OS
+
+ **Checklist**
+ - [ ] all dependencies are satisfied: `conda list` or `pip list` shows the packages listed in the `pyproject.toml`
+ - [ ] the unit tests are working: `pytest -v` reports no errors
+
+ **Additional context**
+ Add any other context about the problem here.
chemprop-updated/.github/ISSUE_TEMPLATE/v1_question.md ADDED
@@ -0,0 +1,17 @@
+ ---
+ name: v1 Question
+ about: Have a question about how to use Chemprop v1?
+ title: "[v1 QUESTION]: "
+ labels: question
+ assignees: ''
+
+ ---
+
+ **What are you trying to do?**
+ Please tell us what you're trying to do with Chemprop, providing as much detail as possible.
+
+ **Previous attempts**
+ If possible, provide some examples of what you've already tried and what the output was.
+
+ **Screenshots**
+ If applicable, add screenshots to help explain your problem.
chemprop-updated/.github/ISSUE_TEMPLATE/v2_bug_report.md ADDED
@@ -0,0 +1,35 @@
+ ---
+ name: v2 Bug Report
+ about: Create a report to help us improve
+ title: "[v2 BUG]: "
+ labels: bug
+ assignees: ''
+
+ ---
+
+ **Describe the bug**
+ A clear and concise description of what the bug is.
+
+ **Example(s)**
+ Provide some examples of where the current code fails. Feel free to share your actual code for additional context, but a minimal and isolated example is preferred.
+
+ **Expected behavior**
+ A clear and concise description of what you expected to happen. If there is correct, expected output, include that here as well.
+
+ **Error Stack Trace**
+ If the bug is resulting in an error message, provide the _full_ stack trace (not just the last line). This is helpful for debugging, especially in cases where you aren't able to provide a minimal/isolated working example with accompanying files.
+
+ **Screenshots**
+ If applicable, add screenshots to help explain your problem.
+
+ **Environment**
+ - python version
+ - package versions: `conda list` or `pip list`
+ - OS
+
+ **Checklist**
+ - [ ] all dependencies are satisfied: `conda list` or `pip list` shows the packages listed in the `pyproject.toml`
+ - [ ] the unit tests are working: `pytest -v` reports no errors
+
+ **Additional context**
+ Add any other context about the problem here.
chemprop-updated/.github/ISSUE_TEMPLATE/v2_feature_request.md ADDED
@@ -0,0 +1,23 @@
+ ---
+ name: v2 Feature Request
+ about: Suggest an idea for this project
+ title: "[v2 FEATURE]: "
+ labels: enhancement
+ assignees: ''
+
+ ---
+
+ **Is your feature request related to a problem? Please describe.**
+ A clear and concise description of what the problem is.
+
+ **Use-cases/examples of this new feature**
+ What are some example workflows that would employ this new feature? Are there any relevant issues?
+
+ **Desired solution/workflow**
+ A clear and concise description of what you want to happen. Include some (pseudo)code, if possible.
+
+ **Discussion**
+ What are some considerations around this new feature? Are there alternative approaches to consider? What should the scope of the feature be?
+
+ **Additional context**
+ Add any other context or screenshots about the feature request here.
chemprop-updated/.github/ISSUE_TEMPLATE/v2_question.md ADDED
@@ -0,0 +1,17 @@
+ ---
+ name: v2 Question
+ about: Have a question about how to use Chemprop v2?
+ title: "[v2 QUESTION]: "
+ labels: question
+ assignees: ''
+
+ ---
+
+ **What are you trying to do?**
+ Please tell us what you're trying to do with Chemprop, providing as much detail as possible.
+
+ **Previous attempts**
+ If possible, provide some examples of what you've already tried and what the output was.
+
+ **Screenshots**
+ If applicable, add screenshots to help explain your problem.
chemprop-updated/.github/PULL_REQUEST_TEMPLATE.md ADDED
@@ -0,0 +1,18 @@
+ ## Description
+ Include a brief summary of the bug/feature/etc. that this PR seeks to address
+
+ ## Example / Current workflow
+ Include a sample workflow to either **(a)** reproduce the bug with the current codebase or **(b)** showcase the deficiency that this PR seeks to address
+
+ ## Bugfix / Desired workflow
+ Include either **(a)** the same workflow from above with the correct output produced via this PR or **(b)** some (pseudo)code containing the new workflow that this PR will (seek to) implement
+
+ ## Questions
+ If there are open questions about implementation strategy or scope of the PR, include them here
+
+ ## Relevant issues
+ If appropriate, please tag them here and include a quick summary
+
+ ## Checklist
+ - [ ] linted with flake8?
+ - [ ] (if appropriate) unit tests added?
chemprop-updated/.github/PULL_REQUEST_TEMPLATE/bugfix.md ADDED
@@ -0,0 +1,12 @@
+ ## Bug report
+ Include a brief summary of the bug that this PR seeks to address. If possible, include relevant issue tags
+
+ ## Example
+ Include a sample execution to reproduce the bug with the current codebase, and some sample output showcasing that the PR fixes this bug
+
+ ## Questions
+ If there are open questions about implementation strategy or scope of the PR, include them here
+
+ ## Checklist
+ - [ ] linted with flake8?
+ - [ ] (if necessary) appropriate unit tests added?
chemprop-updated/.github/PULL_REQUEST_TEMPLATE/new_feature.md ADDED
@@ -0,0 +1,15 @@
+ ## Statement of need
+ What deficiency does this PR seek to address? If there are relevant issues, please tag them here
+
+ ## Current workflow
+ How is this need achieved with the current codebase?
+
+ ## Desired workflow
+ Include some (pseudo)code containing the new workflow that this PR will (seek to) implement
+
+ ## Questions
+ If there are open questions about implementation strategy or scope of the PR, include them here
+
+ ## Checklist
+ - [ ] linted with flake8?
+ - [ ] appropriate unit tests added?
chemprop-updated/.github/workflows/ci.yml ADDED
@@ -0,0 +1,158 @@
+ # ci.yml
+ #
+ # Continuous Integration for Chemprop - checks build, code formatting, and runs tests for all
+ # proposed changes and on a regular schedule
+ #
+ # Note: this file contains extensive inline documentation to aid with knowledge transfer.
+
+ name: Continuous Integration
+
+ on:
+   # run on pushes/pull requests to/against main
+   push:
+     branches: [main]
+   pull_request:
+     branches: [main]
+   # run this in the morning on weekdays to catch dependency issues
+   schedule:
+     - cron: "0 8 * * 1-5"
+   # allow manual runs
+   workflow_dispatch:
+
+ # cancel previously running tests if new commits are made
+ # https://docs.github.com/en/actions/examples/using-concurrency-expressions-and-a-test-matrix
+ concurrency:
+   group: actions-id-${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
+   cancel-in-progress: true
+
+ env:
+   USE_LIBUV: 0  # libuv doesn't work on the GitHub Actions Windows runner
+
+ jobs:
+   build:
+     name: Check Build
+     runs-on: ubuntu-latest
+     steps:
+       # clone the repo, attempt to build
+       - uses: actions/checkout@v4
+       - run: python -m pip install build
+       - run: python -m build .
+
+   lint:
+     name: Check Formatting
+     needs: build
+     runs-on: ubuntu-latest
+     steps:
+       # clone the repo, run black and flake8 on it
+       - uses: actions/checkout@v4
+       - run: python -m pip install black==23.* flake8 isort
+       - run: black --check .
+       - run: flake8 .
+       - run: isort --check .
+
+   test:
+     name: Execute Tests
+     needs: lint
+     runs-on: ${{ matrix.os }}
+     defaults:
+       run:
+         # run with a login shell (so that the conda environment is activated)
+         # and exit immediately if any command fails (for easier debugging)
+         shell: bash -el {0}
+     strategy:
+       # if one platform/python version fails, continue testing the others
+       fail-fast: false
+       matrix:
+         # test on all platforms with both supported versions of Python
+         os: [ubuntu-latest, macos-13, windows-latest]
+         python-version: [3.11, 3.12]
+     steps:
+       - uses: actions/checkout@v4
+       # use a version of the conda virtual environment manager to set up an
+       # isolated environment with the Python version we want
+       - uses: conda-incubator/setup-miniconda@v3
+         with:
+           python-version: ${{ matrix.python-version }}
+           auto-update-conda: true
+           show-channel-urls: true
+           conda-remove-defaults: "true"
+           environment-file: environment.yml
+           activate-environment: chemprop
+       - name: Install dependencies
+         shell: bash -l {0}
+         run: |
+           python -m pip install nbmake
+           python -m pip install ".[dev,docs,test,hpopt]"
+       - name: Test with pytest
+         shell: bash -l {0}
+         run: |
+           pytest -v tests
+       - name: Test notebooks
+         shell: bash -l {0}
+         run: |
+           python -m pip install matplotlib
+           pytest --no-cov -v --nbmake $(find examples -name '*.ipynb' ! -name 'use_featurizer_with_other_libraries.ipynb' ! -name 'shapley_value_with_customized_featurizers.ipynb')
+           pytest --no-cov -v --nbmake $(find docs/source/tutorial/python -name "*.ipynb")
+
+   pypi:
+     name: Build and publish Python 🐍 distributions 📦 to PyPI
+     runs-on: ubuntu-latest
+     # only run if the tests pass
+     needs: [test]
+     # run only on pushes to main on chemprop
+     if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' && github.repository == 'chemprop/chemprop' }}
+     steps:
+       - uses: actions/checkout@master
+       - name: Set up Python 3.11
+         uses: actions/setup-python@v3
+         with:
+           python-version: "3.11"
+       - name: Install pypa/build
+         run: >-
+           python -m
+           pip install
+           build
+           --user
+       - name: Build a binary wheel and a source tarball
+         run: >-
+           python -m
+           build
+           --sdist
+           --wheel
+           --outdir dist/
+           .
+       - name: Publish distribution 📦 to PyPI
+         uses: pypa/gh-action-pypi-publish@release/v1
+         with:
+           password: ${{ secrets.PYPI_API_TOKEN }}
+           skip-existing: true
+           verbose: true
+
+   build-and-push-docker:
+     # shamelessly copied from:
+     # https://github.com/ReactionMechanismGenerator/RMG-Py/blob/bfaee1cad9909a17103a8e6ef9a22569c475964c/.github/workflows/CI.yml#L359C1-L386C54
+     # which is also shamelessly copied from somewhere
+     runs-on: ubuntu-latest
+     # only run if the tests pass
+     needs: [test]
+     # run only on pushes to main on chemprop
+     if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' && github.repository == 'chemprop/chemprop' }}
+     steps:
+       - name: Set up QEMU
+         uses: docker/setup-qemu-action@v2
+
+       - name: Set up Docker Buildx
+         uses: docker/setup-buildx-action@v2
+
+       - name: Login to Docker Hub
+         uses: docker/login-action@v2
+         with:
+           # repository secrets managed by the maintainers
+           username: ${{ secrets.DOCKERHUB_USERNAME }}
+           password: ${{ secrets.DOCKERHUB_TOKEN }}
+
+       - name: Build and Push
+         uses: docker/build-push-action@v4
+         with:
+           push: true
+           tags: chemprop/chemprop:latest
chemprop-updated/.gitignore ADDED
@@ -0,0 +1,178 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ *.idea
+ *.DS_Store
+ *.vscode
+ *.csv
+ *.pkl
+ *.pt
+ *.json
+ *.sqlite3
+ *.yaml
+ *.tfevents.*
+ *.ckpt
+ chemprop/_version.py
+ *.ckpt
+ *.ipynb
+ config.toml
+
+ !tests/data/*
chemprop-updated/.readthedocs.yml ADDED
@@ -0,0 +1,19 @@
+ # .readthedocs.yml
+ # Read the Docs configuration file
+ # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
+
+ # Required
+ version: 2
+
+ # Set the OS, Python version and other tools you might need
+ build:
+   os: ubuntu-22.04
+   tools:
+     python: "3.11"
+   jobs:
+     post_install:
+       - python -m pip install --upgrade --upgrade-strategy only-if-needed --no-cache-dir ".[docs]"
+
+ # Build documentation in the docs/ directory with Sphinx
+ sphinx:
+   configuration: docs/source/conf.py
chemprop-updated/CITATIONS.bib ADDED
@@ -0,0 +1,37 @@
+ # this was downloaded from ACS: https://pubs.acs.org/doi/10.1021/acs.jcim.9b00237
+ @article{chemprop_theory,
+   author = {Yang, Kevin and Swanson, Kyle and Jin, Wengong and Coley, Connor and Eiden, Philipp and Gao, Hua and Guzman-Perez, Angel and Hopper, Timothy and Kelley, Brian and Mathea, Miriam and Palmer, Andrew and Settels, Volker and Jaakkola, Tommi and Jensen, Klavs and Barzilay, Regina},
+   title = {Analyzing Learned Molecular Representations for Property Prediction},
+   journal = {Journal of Chemical Information and Modeling},
+   volume = {59},
+   number = {8},
+   pages = {3370-3388},
+   year = {2019},
+   doi = {10.1021/acs.jcim.9b00237},
+   note = {PMID: 31361484},
+   URL = {https://doi.org/10.1021/acs.jcim.9b00237},
+   eprint = {https://doi.org/10.1021/acs.jcim.9b00237}
+ }
+
+ # this was downloaded from ACS: https://pubs.acs.org/doi/10.1021/acs.jcim.3c01250
+ @article{chemprop_software,
+   author = {Heid, Esther and Greenman, Kevin P. and Chung, Yunsie and Li, Shih-Cheng and Graff, David E. and Vermeire, Florence H. and Wu, Haoyang and Green, William H. and McGill, Charles J.},
+   title = {Chemprop: A Machine Learning Package for Chemical Property Prediction},
+   journal = {Journal of Chemical Information and Modeling},
+   volume = {64},
+   number = {1},
+   pages = {9-17},
+   year = {2024},
+   doi = {10.1021/acs.jcim.3c01250},
+   note = {PMID: 38147829},
+   URL = {https://doi.org/10.1021/acs.jcim.3c01250},
+   eprint = {https://doi.org/10.1021/acs.jcim.3c01250}
+ }
chemprop-updated/CONTRIBUTING.md ADDED
@@ -0,0 +1,40 @@
+ # How to contribute
+
+ We welcome contributions from external contributors, and this document
+ describes how to merge code changes into this repository.
+
+ ## Getting Started
+
+ * Make sure you have a [GitHub account](https://github.com/signup/free).
+ * [Fork](https://help.github.com/articles/fork-a-repo/) this repository on GitHub.
+ * On your local machine, [clone](https://help.github.com/articles/cloning-a-repository/) your fork of the repository.
+
+ ## Making Changes
+
+ * Add some really awesome code to your local fork. It's usually a [good idea](http://blog.jasonmeridth.com/posts/do-not-issue-pull-requests-from-your-master-branch/) to make changes on a [branch](https://help.github.com/articles/creating-and-deleting-branches-within-your-repository/) with the branch name relating to the feature you are going to add.
+ * When you are ready for others to examine and comment on your new feature, navigate to your fork of `chemprop` on GitHub and open a [pull request](https://help.github.com/articles/using-pull-requests/) (PR). Note that after you launch a PR from one of your fork's branches, all subsequent commits to that branch will be added to the open pull request automatically. Each commit added to the PR will be validated for mergeability, compilation, and test suite compliance; the results of these tests will be visible on the PR page.
+ * If you're providing a new feature, you **must** add test cases and documentation.
+ * When the code is ready to go, run the test suite: `pytest`.
+ * When you're ready to be considered for merging, click the "Ready for review" box on the PR page to let the Chemprop devs know that the changes are complete. The code will not be merged until the continuous integration returns checkmarks and at least one core developer gives an "Approved" review.
+
+ ## Additional Resources
+
+ * [General GitHub documentation](https://help.github.com/)
+ * [PR best practices](http://codeinthehole.com/writing/pull-requests-and-other-good-practices-for-teams-using-github/)
+ * [A guide to contributing to software packages](http://www.contribution-guide.org)
chemprop-updated/Dockerfile ADDED
@@ -0,0 +1,50 @@
+ # Dockerfile
+ #
+ # Builds a Docker image containing Chemprop and its required dependencies.
+ #
+ # Build this image with:
+ #   git clone https://github.com/chemprop/chemprop.git
+ #   cd chemprop
+ #   docker build --tag=chemprop:latest .
+ #
+ # Run the built image with:
+ #   docker run --name chemprop_container -it chemprop:latest
+ #
+ # Note:
+ # This image only runs on CPU - we do not provide a Dockerfile
+ # for GPU use (see installation documentation).
+
+ # Parent Image
+ FROM continuumio/miniconda3:latest
+
+ # Install libxrender1 (required by RDKit) and then clean up
+ RUN apt-get update && \
+     apt-get install -y \
+     libxrender1 && \
+     apt-get autoremove -y && \
+     apt-get clean -y
+
+ WORKDIR /opt/chemprop
+
+ # build an empty conda environment with appropriate Python version
+ RUN conda create --name chemprop_env python=3.11*
+
+ # This runs all subsequent commands inside the chemprop_env conda environment
+ #
+ # Analogous to just activating the environment, which we can't actually do here
+ # since that requires running conda init and restarting the shell (not possible
+ # in a Dockerfile build script)
+ SHELL ["conda", "run", "--no-capture-output", "-n", "chemprop_env", "/bin/bash", "-c"]
+
+ # Follow the installation instructions then clear the cache
+ ADD chemprop chemprop
+ ENV PYTHONPATH /opt/chemprop
+ ADD LICENSE.txt pyproject.toml README.md ./
+ RUN conda install pytorch cpuonly -c pytorch && \
+     conda clean --all --yes && \
+     python -m pip install . && \
+     python -m pip cache purge
+
+ # when running this image, open an interactive bash terminal inside the conda environment
+ RUN echo "conda activate chemprop_env" > ~/.bashrc
+ ENTRYPOINT ["/bin/bash", "--login"]
chemprop-updated/LICENSE.txt ADDED
@@ -0,0 +1,27 @@
+ MIT License
+
+ Copyright (c) 2024 The Chemprop Development Team (Regina Barzilay,
+ Jackson Burns, Yunsie Chung, Anna Doner, Xiaorui Dong, David Graff,
+ William Green, Kevin Greenman, Yanfei Guan, Esther Heid, Lior Hirschfeld,
+ Tommi Jaakkola, Wengong Jin, Olivier Lafontant-Joseph, Shih-Cheng Li,
+ Mengjie Liu, Joel Manu, Charles McGill, Angiras Menon, Nathan Morgan,
+ Hao-Wei Pang, Kevin Spiekermann, Kyle Swanson, Allison Tam,
+ Florence Vermeire, Haoyang Wu, Kevin Yang, and Jonathan Zheng)
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
chemprop-updated/README.md ADDED
@@ -0,0 +1,63 @@
+ <picture>
+   <source media="(prefers-color-scheme: dark)" srcset="docs/source/_static/images/logo/chemprop_logo_dark_mode.svg">
+   <img alt="Chemprop Logo" src="docs/source/_static/images/logo/chemprop_logo.svg">
+ </picture>
+
+ # Chemprop
+
+ [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/chemprop)](https://badge.fury.io/py/chemprop)
+ [![PyPI version](https://badge.fury.io/py/chemprop.svg)](https://badge.fury.io/py/chemprop)
+ [![Anaconda-Server Badge](https://anaconda.org/conda-forge/chemprop/badges/version.svg)](https://anaconda.org/conda-forge/chemprop)
+ [![Build Status](https://github.com/chemprop/chemprop/workflows/tests/badge.svg)](https://github.com/chemprop/chemprop/actions/workflows/tests.yml)
+ [![Documentation Status](https://readthedocs.org/projects/chemprop/badge/?version=main)](https://chemprop.readthedocs.io/en/main/?badge=main)
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
+ [![Downloads](https://static.pepy.tech/badge/chemprop)](https://pepy.tech/project/chemprop)
+ [![Downloads](https://static.pepy.tech/badge/chemprop/month)](https://pepy.tech/project/chemprop)
+ [![Downloads](https://static.pepy.tech/badge/chemprop/week)](https://pepy.tech/project/chemprop)
+
+ Chemprop is a repository containing message passing neural networks for molecular property prediction.
+
+ Documentation can be found [here](https://chemprop.readthedocs.io/en/main/).
+
+ There are tutorial notebooks in the [`examples/`](https://github.com/chemprop/chemprop/tree/main/examples) directory.
+
+ Chemprop recently underwent a ground-up rewrite and new major release (v2.0.0). A helpful transition guide from Chemprop v1 to v2 can be found [here](https://docs.google.com/spreadsheets/u/3/d/e/2PACX-1vRshySIknVBBsTs5P18jL4WeqisxDAnDE5VRnzxqYEhYrMe4GLS17w5KeKPw9sged6TmmPZ4eEZSTIy/pubhtml). It includes a side-by-side comparison of CLI argument options, a list of which arguments will be implemented in later versions of v2, and a list of changes to default hyperparameters.
+
+ **License:** Chemprop is free to use under the [MIT License](LICENSE.txt). The Chemprop logo is free to use under [CC0 1.0](docs/source/_static/images/logo/LICENSE.txt).
+
+ **References**: Please cite the appropriate papers if Chemprop is helpful to your research.
+
+ - Chemprop was initially described in the papers [Analyzing Learned Molecular Representations for Property Prediction](https://pubs.acs.org/doi/abs/10.1021/acs.jcim.9b00237) for molecules and [Machine Learning of Reaction Properties via Learned Representations of the Condensed Graph of Reaction](https://doi.org/10.1021/acs.jcim.1c00975) for reactions.
+ - The interpretation functionality (available in v1, but not yet implemented in v2) is based on the paper [Multi-Objective Molecule Generation using Interpretable Substructures](https://arxiv.org/abs/2002.03244).
+ - Chemprop now has its own dedicated manuscript that describes and benchmarks it in more detail: [Chemprop: A Machine Learning Package for Chemical Property Prediction](https://doi.org/10.1021/acs.jcim.3c01250).
+ - A paper describing and benchmarking the changes in v2.0.0 is forthcoming.
+
+ **Selected Applications**: Chemprop has been successfully used in the following works.
+
+ - [A Deep Learning Approach to Antibiotic Discovery](https://www.cell.com/cell/fulltext/S0092-8674(20)30102-1) - _Cell_ (2020): Chemprop was used to predict antibiotic activity against _E. coli_, leading to the discovery of [Halicin](https://en.wikipedia.org/wiki/Halicin), a novel antibiotic candidate. Model checkpoints are available on [Zenodo](https://doi.org/10.5281/zenodo.6527882).
+ - [Discovery of a structural class of antibiotics with explainable deep learning](https://www.nature.com/articles/s41586-023-06887-8) - _Nature_ (2023): Identified a structural class of antibiotics selective against methicillin-resistant _S. aureus_ (MRSA) and vancomycin-resistant enterococci using ensembles of Chemprop models, and explained results using Chemprop's interpret method.
+ - [ADMET-AI: A machine learning ADMET platform for evaluation of large-scale chemical libraries](https://academic.oup.com/bioinformatics/advance-article/doi/10.1093/bioinformatics/btae416/7698030?utm_source=authortollfreelink&utm_campaign=bioinformatics&utm_medium=email&guestAccessKey=f4fca1d2-49ec-4b10-b476-5aea3bf37045): Chemprop was trained on 41 absorption, distribution, metabolism, excretion, and toxicity (ADMET) datasets from the [Therapeutics Data Commons](https://tdcommons.ai). The Chemprop models in ADMET-AI are available both as a web server at [admet.ai.greenstonebio.com](https://admet.ai.greenstonebio.com) and as a Python package at [github.com/swansonk14/admet_ai](https://github.com/swansonk14/admet_ai).
+ - A more extensive list of successful Chemprop applications is given in our [2023 paper](https://doi.org/10.1021/acs.jcim.3c01250).
+
+ ## Version 1.x
+
+ For users who have not yet made the switch to Chemprop v2.0, please reference the following resources.
+
+ ### v1 Documentation
+
+ - Documentation of Chemprop v1 is available [here](https://chemprop.readthedocs.io/en/v1.7.1/). Note that the content of this site is several versions behind the final v1 release (v1.7.1) and does not cover the full scope of features available in Chemprop v1.
+ - The v1 [README](https://github.com/chemprop/chemprop/blob/v1.7.1/README.md) is the best source for documentation on more recently added features.
+ - Please also see descriptions of all the possible command line arguments in the v1 [`args.py`](https://github.com/chemprop/chemprop/blob/v1.7.1/chemprop/args.py) file.
+
+ ### v1 Tutorials and Examples
+
+ - [Benchmark scripts](https://github.com/chemprop/chemprop_benchmark) - scripts from our 2023 paper, providing examples of many features using Chemprop v1.6.1
+ - [ACS Fall 2023 Workshop](https://github.com/chemprop/chemprop-workshop-acs-fall2023) - presentation, interactive demo, and exercises on Google Colab with a solution key
+ - [Google Colab notebook](https://colab.research.google.com/github/chemprop/chemprop/blob/v1.7.1/colab_demo.ipynb) - several examples, intended to be run in Google Colab rather than as a Jupyter notebook on your local machine
+ - [nanoHUB tool](https://nanohub.org/resources/chempropdemo/) - a notebook of examples similar to the Colab notebook above; doesn't require any installation
+ - [YouTube video](https://www.youtube.com/watch?v=TeOl5E8Wo2M) - lecture accompanying the nanoHUB tool
+ - These [slides](https://docs.google.com/presentation/d/14pbd9LTXzfPSJHyXYkfLxnK8Q80LhVnjImg8a3WqCRM/edit?usp=sharing) provide a Chemprop tutorial and highlight additions as of April 28th, 2020
+
+ ### v1 Known Issues
+
+ We have discontinued support for v1 since v2 has been released, but we still appreciate v1 bug reports and will tag them as [`v1-wontfix`](https://github.com/chemprop/chemprop/issues?q=label%3Av1-wontfix+) so the community can find them easily.
chemprop-updated/chemprop/__init__.py ADDED
@@ -0,0 +1,5 @@
+ from . import data, exceptions, featurizers, models, nn, schedulers, utils
+
+ __all__ = ["data", "featurizers", "models", "nn", "utils", "exceptions", "schedulers"]
+
+ __version__ = "2.1.2"
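The new top-level `__init__.py` re-exports the package's submodules and pins the version string that `.bumpversion.cfg` tracks. A minimal smoke test, assuming the package is installed in the current environment:

```python
# Minimal smoke test of the package layout defined above
# (assumes chemprop v2 is installed in the current environment).
import chemprop

assert chemprop.__version__ == "2.1.2"  # the string tracked by .bumpversion.cfg
print(sorted(chemprop.__all__))         # data, exceptions, featurizers, ...

from chemprop import data, models       # submodules importable from the top level
```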
chemprop-updated/chemprop/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (743 Bytes)
chemprop-updated/chemprop/__pycache__/args.cpython-37.pyc ADDED
Binary file (33.7 kB)
chemprop-updated/chemprop/__pycache__/constants.cpython-37.pyc ADDED
Binary file (430 Bytes)
chemprop-updated/chemprop/__pycache__/hyperopt_utils.cpython-37.pyc ADDED
Binary file (11.1 kB)
chemprop-updated/chemprop/__pycache__/hyperparameter_optimization.cpython-37.pyc ADDED
Binary file (6.15 kB)
chemprop-updated/chemprop/__pycache__/interpret.cpython-37.pyc ADDED
Binary file (14.2 kB)
chemprop-updated/chemprop/__pycache__/multitask_utils.cpython-37.pyc ADDED
Binary file (3.12 kB)
chemprop-updated/chemprop/__pycache__/nn_utils.cpython-37.pyc ADDED
Binary file (8.13 kB)
chemprop-updated/chemprop/__pycache__/rdkit.cpython-37.pyc ADDED
Binary file (1.43 kB)
chemprop-updated/chemprop/__pycache__/sklearn_predict.cpython-37.pyc ADDED
Binary file (2.82 kB)
chemprop-updated/chemprop/__pycache__/sklearn_train.cpython-37.pyc ADDED
Binary file (11.4 kB)
chemprop-updated/chemprop/__pycache__/spectra_utils.cpython-37.pyc ADDED
Binary file (5.1 kB)
chemprop-updated/chemprop/__pycache__/utils.cpython-37.pyc ADDED
Binary file (26.5 kB)
chemprop-updated/chemprop/cli/common.py ADDED
@@ -0,0 +1,211 @@
+ from argparse import ArgumentError, ArgumentParser, Namespace
+ import logging
+ from pathlib import Path
+
+ from chemprop.cli.utils import LookupAction
+ from chemprop.cli.utils.args import uppercase
+ from chemprop.featurizers import AtomFeatureMode, MoleculeFeaturizerRegistry, RxnMode
+
+ logger = logging.getLogger(__name__)
+
+
+ def add_common_args(parser: ArgumentParser) -> ArgumentParser:
+     data_args = parser.add_argument_group("Shared input data args")
+     data_args.add_argument(
+         "-s",
+         "--smiles-columns",
+         nargs="+",
+         help="Column names in the input CSV containing SMILES strings (uses the 0th column by default)",
+     )
+     data_args.add_argument(
+         "-r",
+         "--reaction-columns",
+         nargs="+",
+         help="Column names in the input CSV containing reaction SMILES in the format ``REACTANT>AGENT>PRODUCT``, where 'AGENT' is optional",
+     )
+     data_args.add_argument(
+         "--no-header-row",
+         action="store_true",
+         help="Turn off using the first row in the input CSV as column names",
+     )
+
+     dataloader_args = parser.add_argument_group("Dataloader args")
+     dataloader_args.add_argument(
+         "-n",
+         "--num-workers",
+         type=int,
+         default=0,
+         help="""Number of workers for parallel data loading, where 0 means sequential
+ (Warning: setting ``num_workers`` to a value greater than 0 can cause hangs on Windows and MacOS)""",
+     )
+     dataloader_args.add_argument("-b", "--batch-size", type=int, default=64, help="Batch size")
+
+     parser.add_argument(
+         "--accelerator", default="auto", help="Passed directly to the lightning ``Trainer()``"
+     )
+     parser.add_argument(
+         "--devices",
+         default="auto",
+         help="Passed directly to the lightning ``Trainer()`` (must be a single string of comma-separated devices, e.g. '1, 2' if specifying multiple devices)",
+     )
+
+     featurization_args = parser.add_argument_group("Featurization args")
+     featurization_args.add_argument(
+         "--rxn-mode",
+         "--reaction-mode",
+         type=uppercase,
+         default="REAC_DIFF",
+         choices=list(RxnMode.keys()),
+         help="""Choices for construction of atom and bond features for reactions (case insensitive):
+
+ - ``REAC_PROD``: concatenates the reactants feature with the products feature
+ - ``REAC_DIFF``: concatenates the reactants feature with the difference in features between reactants and products (default)
+ - ``PROD_DIFF``: concatenates the products feature with the difference in features between reactants and products
+ - ``REAC_PROD_BALANCE``: concatenates the reactants feature with the products feature, balances imbalanced reactions
+ - ``REAC_DIFF_BALANCE``: concatenates the reactants feature with the difference in features between reactants and products, balances imbalanced reactions
+ - ``PROD_DIFF_BALANCE``: concatenates the products feature with the difference in features between reactants and products, balances imbalanced reactions""",
+     )
+     # TODO: Update documentation for multi_hot_atom_featurizer_mode
+     featurization_args.add_argument(
+         "--multi-hot-atom-featurizer-mode",
+         type=uppercase,
+         default="V2",
+         choices=list(AtomFeatureMode.keys()),
+         help="""Choices for multi-hot atom featurization scheme. This will affect both non-reaction and reaction featurization (case insensitive):
+
+ - ``V1``: Corresponds to the original configuration employed in Chemprop v1
+ - ``V2``: Tailored for a broad range of molecules, this configuration encompasses all elements in the first four rows of the periodic table, along with iodine. It is the default in Chemprop v2.
+ - ``ORGANIC``: This configuration is designed specifically for use with organic molecules for drug research and development and includes a subset of elements most common in organic chemistry, including H, B, C, N, O, F, Si, P, S, Cl, Br, and I.
+ - ``RIGR``: Modified V2 (default) featurizer using only the resonance-invariant atom and bond features""",
+     )
+     featurization_args.add_argument(
+         "--keep-h",
+         action="store_true",
+         help="Whether hydrogens explicitly specified in input should be kept in the mol graph",
+     )
+     featurization_args.add_argument(
+         "--add-h", action="store_true", help="Whether hydrogens should be added to the mol graph"
+     )
+     featurization_args.add_argument(
+         "--molecule-featurizers",
+         "--features-generators",
+         nargs="+",
+         action=LookupAction(MoleculeFeaturizerRegistry),
+         help="Method(s) of generating molecule features to use as extra descriptors",
+     )
+     # TODO: add in v2.1 to deprecate features-generators and then remove in v2.2
+     # featurization_args.add_argument(
+     #     "--features-generators", nargs="+", help="Renamed to `--molecule-featurizers`."
+     # )
+     featurization_args.add_argument(
+         "--descriptors-path",
+         type=Path,
+         help="Path to extra descriptors to concatenate to learned representation",
+     )
+     # TODO: Add in v2.1
+     # featurization_args.add_argument(
+     #     "--phase-features-path",
+     #     help="Path to features used to indicate the phase of the data in one-hot vector form. Used in spectra datatype.",
+     # )
+     featurization_args.add_argument(
+         "--no-descriptor-scaling", action="store_true", help="Turn off extra descriptor scaling"
+     )
+     featurization_args.add_argument(
+         "--no-atom-feature-scaling", action="store_true", help="Turn off extra atom feature scaling"
+     )
+     featurization_args.add_argument(
+         "--no-atom-descriptor-scaling",
+         action="store_true",
+         help="Turn off extra atom descriptor scaling",
+     )
+     featurization_args.add_argument(
+         "--no-bond-feature-scaling", action="store_true", help="Turn off extra bond feature scaling"
+     )
+     featurization_args.add_argument(
+         "--atom-features-path",
+         nargs="+",
+         action="append",
+         help="If a single path is given, it is assumed to correspond to the 0-th molecule. Alternatively, it can be a two-tuple of molecule index and path to additional atom features to supply before message passing (e.g., ``--atom-features-path 0 /path/to/features_0.npz`` indicates that the features at the given path should be supplied to the 0-th component). To supply additional features for multiple components, repeat this argument on the command line for each component's respective values (e.g., ``--atom-features-path [...] --atom-features-path [...]``).",
+     )
+     featurization_args.add_argument(
+         "--atom-descriptors-path",
+         nargs="+",
+         action="append",
+         help="If a single path is given, it is assumed to correspond to the 0-th molecule. Alternatively, it can be a two-tuple of molecule index and path to additional atom descriptors to supply after message passing (e.g., ``--atom-descriptors-path 0 /path/to/descriptors_0.npz`` indicates that the descriptors at the given path should be supplied to the 0-th component). To supply additional descriptors for multiple components, repeat this argument on the command line for each component's respective values (e.g., ``--atom-descriptors-path [...] --atom-descriptors-path [...]``).",
+     )
+     featurization_args.add_argument(
+         "--bond-features-path",
+         nargs="+",
+         action="append",
+         help="If a single path is given, it is assumed to correspond to the 0-th molecule. Alternatively, it can be a two-tuple of molecule index and path to additional bond features to supply before message passing (e.g., ``--bond-features-path 0 /path/to/features_0.npz`` indicates that the features at the given path should be supplied to the 0-th component). To supply additional features for multiple components, repeat this argument on the command line for each component's respective values (e.g., ``--bond-features-path [...] --bond-features-path [...]``).",
+     )
+     # TODO: Add in v2.2
+     # parser.add_argument(
+     #     "--constraints-path",
+     #     help="Path to constraints applied to atomic/bond properties prediction.",
+     # )
+
+     return parser
+
+
+ def process_common_args(args: Namespace) -> Namespace:
+     # TODO: add in v2.1 to deprecate features-generators and then remove in v2.2
+     # if args.features_generators is not None:
+     #     raise ArgumentError(
+     #         argument=None,
+     #         message="`--features-generators` has been renamed to `--molecule-featurizers`.",
+     #     )
+
+     for key in ["atom_features_path", "atom_descriptors_path", "bond_features_path"]:
+         inds_paths = getattr(args, key)
+
+         if not inds_paths:
+             continue
+
+         ind_path_dict = {}
+
+         for ind_path in inds_paths:
+             if len(ind_path) > 2:
+                 raise ArgumentError(
+                     argument=None,
+                     message="Too many arguments were given for atom features/descriptors or bond features. It should be either a two-tuple of molecule index and a path, or a single path (assumed to be the 0-th molecule).",
+                 )
+
+             if len(ind_path) == 1:
+                 ind = 0
+                 path = ind_path[0]
+             else:
+                 ind, path = ind_path
+
+             if ind_path_dict.get(int(ind), None):
+                 raise ArgumentError(
+                     argument=None,
+                     message=f"Duplicate atom features/descriptors or bond features given for molecule index {ind}",
+                 )
+
+             ind_path_dict[int(ind)] = Path(path)
+
+         setattr(args, key, ind_path_dict)
+
+     return args
+
+
+ def validate_common_args(args):
+     pass
+
+
+ def find_models(model_paths: list[Path]):
+     collected_model_paths = []
+
+     for model_path in model_paths:
+         if model_path.suffix in [".ckpt", ".pt"]:
+             collected_model_paths.append(model_path)
+         elif model_path.is_dir():
+             collected_model_paths.extend(list(model_path.rglob("*.pt")))
+         else:
+             raise ArgumentError(
+                 argument=None,
+                 message=f"Expected a .ckpt or .pt file, or a directory. Got {model_path}",
+             )
+
+     return collected_model_paths
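To illustrate `process_common_args`: repeated `--atom-features-path` values arrive as lists of `[index, path]` pairs (or a bare path, assumed to be the 0-th molecule) and are normalized into a dict keyed by molecule index. A small sketch, assuming the module is importable in the current environment; the feature paths are hypothetical:

```python
# Sketch of the index/path normalization in process_common_args()
# (paths are hypothetical; all three keys must exist on the namespace).
from argparse import Namespace

from chemprop.cli.common import process_common_args

args = Namespace(
    atom_features_path=[["0", "feats_0.npz"], ["1", "feats_1.npz"]],  # two components
    atom_descriptors_path=None,
    bond_features_path=None,
)
args = process_common_args(args)
print(args.atom_features_path)  # e.g., {0: Path('feats_0.npz'), 1: Path('feats_1.npz')}
```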
chemprop-updated/chemprop/cli/conf.py ADDED
@@ -0,0 +1,9 @@
+ from datetime import datetime
+ import logging
+ import os
+ from pathlib import Path
+
+ LOG_DIR = Path(os.getenv("CHEMPROP_LOG_DIR", "chemprop_logs"))
+ LOG_LEVELS = {0: logging.INFO, 1: logging.DEBUG, -1: logging.WARNING, -2: logging.ERROR}
+ NOW = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+ CHEMPROP_TRAIN_DIR = Path(os.getenv("CHEMPROP_TRAIN_DIR", "chemprop_training"))
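Note that these constants are resolved from the environment at import time, so any override has to be in place before `chemprop.cli.conf` is first imported. A minimal sketch (the directory names are hypothetical):

```python
# Override the log/training directories before the module is first imported;
# once imported, LOG_DIR and CHEMPROP_TRAIN_DIR are fixed for the process.
import os

os.environ["CHEMPROP_LOG_DIR"] = "/tmp/my_chemprop_logs"        # hypothetical paths
os.environ["CHEMPROP_TRAIN_DIR"] = "/tmp/my_chemprop_training"

from chemprop.cli.conf import CHEMPROP_TRAIN_DIR, LOG_DIR

print(LOG_DIR, CHEMPROP_TRAIN_DIR)
```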
chemprop-updated/chemprop/cli/convert.py ADDED
@@ -0,0 +1,55 @@
+ from argparse import ArgumentError, ArgumentParser, Namespace
+ import logging
+ from pathlib import Path
+ import sys
+
+ from chemprop.cli.utils import Subcommand
+ from chemprop.utils.v1_to_v2 import convert_model_file_v1_to_v2
+
+ logger = logging.getLogger(__name__)
+
+
+ class ConvertSubcommand(Subcommand):
+     COMMAND = "convert"
+     HELP = "Convert a v1 model checkpoint (.pt) to a v2 model checkpoint (.pt)."
+
+     @classmethod
+     def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
+         parser.add_argument(
+             "-i",
+             "--input-path",
+             required=True,
+             type=Path,
+             help="Path to a v1 model .pt checkpoint file",
+         )
+         parser.add_argument(
+             "-o",
+             "--output-path",
+             type=Path,
+             help="Path to which the converted model will be saved (``CURRENT_DIRECTORY/STEM_OF_INPUT_v2.pt`` by default)",
+         )
+         return parser
+
+     @classmethod
+     def func(cls, args: Namespace):
+         if args.output_path is None:
+             args.output_path = Path(args.input_path.stem + "_v2.pt")
+         if args.output_path.suffix != ".pt":
+             raise ArgumentError(
+                 argument=None, message=f"Output must be a `.pt` file. Got {args.output_path}"
+             )
+
+         logger.info(
+             f"Converting v1 model checkpoint '{args.input_path}' to v2 model checkpoint '{args.output_path}'..."
+         )
+         convert_model_file_v1_to_v2(args.input_path, args.output_path)
+
+
+ if __name__ == "__main__":
+     parser = ArgumentParser()
+     parser = ConvertSubcommand.add_args(parser)
+
+     logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
+
+     args = parser.parse_args()
+     ConvertSubcommand.func(args)
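The subcommand is a thin wrapper around `convert_model_file_v1_to_v2`, so the same conversion can be scripted directly. A sketch with hypothetical file names, mirroring the subcommand's default output naming:

```python
# Programmatic equivalent of `chemprop convert -i model_v1.pt`
# (file names are hypothetical; the v1 checkpoint must exist on disk).
from pathlib import Path

from chemprop.utils.v1_to_v2 import convert_model_file_v1_to_v2

input_path = Path("model_v1.pt")
output_path = Path(input_path.stem + "_v2.pt")  # mirrors the subcommand's default
convert_model_file_v1_to_v2(input_path, output_path)
```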
chemprop-updated/chemprop/cli/fingerprint.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from argparse import ArgumentError, ArgumentParser, Namespace
2
+ import logging
3
+ from pathlib import Path
4
+ import sys
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+
10
+ from chemprop import data
11
+ from chemprop.cli.common import add_common_args, process_common_args, validate_common_args
12
+ from chemprop.cli.predict import find_models
13
+ from chemprop.cli.utils import Subcommand, build_data_from_files, make_dataset
14
+ from chemprop.models import load_model
15
+ from chemprop.nn.metrics import LossFunctionRegistry
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ class FingerprintSubcommand(Subcommand):
21
+ COMMAND = "fingerprint"
22
+ HELP = "Use a pretrained chemprop model to calculate learned representations."
23
+
24
+ @classmethod
25
+ def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
26
+ parser = add_common_args(parser)
27
+ parser.add_argument(
28
+ "-i",
29
+ "--test-path",
30
+ required=True,
31
+ type=Path,
32
+ help="Path to an input CSV file containing SMILES",
33
+ )
34
+ parser.add_argument(
35
+ "-o",
36
+ "--output",
37
+ "--preds-path",
38
+ type=Path,
39
+ help="Specify the path where predictions will be saved. If the file extension is .npz, they will be saved as a npz file. Otherwise, the predictions will be saved as a CSV. The index of the model will be appended to the filename's stem. By default, predictions will be saved to the same location as ``--test-path`` with '_fps' appended (e.g., 'PATH/TO/TEST_PATH_fps_0.csv').",
40
+ )
41
+ parser.add_argument(
42
+ "--model-paths",
43
+ "--model-path",
44
+ required=True,
45
+ type=Path,
46
+ nargs="+",
47
+ help="Specify location of checkpoint(s) or model file(s) to use for prediction. It can be a path to either a single pretrained model checkpoint (.ckpt) or single pretrained model file (.pt), a directory that contains these files, or a list of path(s) and directory(s). If a directory, chemprop will recursively search and predict on all found (.pt) models.",
48
+ )
49
+ parser.add_argument(
50
+ "--ffn-block-index",
51
+ required=True,
52
+ type=int,
53
+ default=-1,
54
+ help="The index indicates which linear layer returns the encoding in the FFN. An index of 0 denotes the post-aggregation representation through a 0-layer MLP, while an index of 1 represents the output from the first linear layer in the FFN, and so forth.",
55
+ )
56
+
57
+ return parser
58
+
59
+ @classmethod
60
+ def func(cls, args: Namespace):
61
+ args = process_common_args(args)
62
+ validate_common_args(args)
63
+ args = process_fingerprint_args(args)
64
+ main(args)
65
+
66
+
67
+ def process_fingerprint_args(args: Namespace) -> Namespace:
68
+ if args.test_path.suffix not in [".csv"]:
69
+ raise ArgumentError(
70
+ argument=None, message=f"Input data must be a CSV file. Got {args.test_path}"
71
+ )
72
+ if args.output is None:
73
+ args.output = args.test_path.parent / (args.test_path.stem + "_fps.csv")
74
+ if args.output.suffix not in [".csv", ".npz"]:
75
+ raise ArgumentError(
76
+ argument=None, message=f"Output must be a CSV or NPZ file. Got '{args.output}'."
77
+ )
78
+ return args
79
+
80
+
81
+ def make_fingerprint_for_model(
82
+ args: Namespace, model_path: Path, multicomponent: bool, output_path: Path
83
+ ):
84
+ model = load_model(model_path, multicomponent)
85
+ model.eval()
86
+
87
+ bounded = any(
88
+ isinstance(model.criterion, LossFunctionRegistry[loss_function])
89
+ for loss_function in LossFunctionRegistry.keys()
90
+ if "bounded" in loss_function
91
+ )
92
+
93
+ format_kwargs = dict(
94
+ no_header_row=args.no_header_row,
95
+ smiles_cols=args.smiles_columns,
96
+ rxn_cols=args.reaction_columns,
97
+ target_cols=[],
98
+ ignore_cols=None,
99
+ splits_col=None,
100
+ weight_col=None,
101
+ bounded=bounded,
102
+ )
103
+
104
+ featurization_kwargs = dict(
105
+ molecule_featurizers=args.molecule_featurizers, keep_h=args.keep_h, add_h=args.add_h
106
+ )
107
+
108
+ test_data = build_data_from_files(
109
+ args.test_path,
110
+ **format_kwargs,
111
+ p_descriptors=args.descriptors_path,
112
+ p_atom_feats=args.atom_features_path,
113
+ p_bond_feats=args.bond_features_path,
114
+ p_atom_descs=args.atom_descriptors_path,
115
+ **featurization_kwargs,
116
+ )
117
+ logger.info(f"test size: {len(test_data[0])}")
118
+ test_dsets = [
119
+ make_dataset(d, args.rxn_mode, args.multi_hot_atom_featurizer_mode) for d in test_data
120
+ ]
121
+
122
+ if multicomponent:
123
+ test_dset = data.MulticomponentDataset(test_dsets)
124
+ else:
125
+ test_dset = test_dsets[0]
126
+
127
+ test_loader = data.build_dataloader(test_dset, args.batch_size, args.num_workers, shuffle=False)
128
+
129
+ logger.info(model)
130
+
131
+ with torch.no_grad():
132
+ if multicomponent:
133
+ encodings = [
134
+ model.encoding(batch.bmgs, batch.V_ds, batch.X_d, args.ffn_block_index)
135
+ for batch in test_loader
136
+ ]
137
+ else:
138
+ encodings = [
139
+ model.encoding(batch.bmg, batch.V_d, batch.X_d, args.ffn_block_index)
140
+ for batch in test_loader
141
+ ]
142
+ H = torch.cat(encodings, 0).numpy()
143
+
144
+ if output_path.suffix in [".npz"]:
145
+ np.savez(output_path, H=H)
146
+ elif output_path.suffix == ".csv":
147
+ fingerprint_columns = [f"fp_{i}" for i in range(H.shape[1])]
148
+ df_fingerprints = pd.DataFrame(H, columns=fingerprint_columns)
149
+ df_fingerprints.to_csv(output_path, index=False)
150
+ else:
151
+ raise ArgumentError(
152
+ argument=None, message=f"Output must be a CSV or NPZ file. Got '{output_path}'."
153
+ )
154
+ logger.info(f"Fingerprints saved to '{output_path}'")
155
+
156
+
157
+ def main(args):
158
+ match (args.smiles_columns, args.reaction_columns):
159
+ case [None, None]:
160
+ n_components = 1
161
+ case [_, None]:
162
+ n_components = len(args.smiles_columns)
163
+ case [None, _]:
164
+ n_components = len(args.reaction_columns)
165
+ case _:
166
+ n_components = len(args.smiles_columns) + len(args.reaction_columns)
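+ # e.g., two SMILES columns plus one reaction column gives n_components == 3,
+ # which routes everything below through the multicomponent code path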
167
+
168
+ multicomponent = n_components > 1
169
+
170
+ for i, model_path in enumerate(find_models(args.model_paths)):
171
+ logger.info(f"Fingerprints with model {i} at '{model_path}'")
172
+ output_path = args.output.parent / f"{args.output.stem}_{i}{args.output.suffix}"
173
+ make_fingerprint_for_model(args, model_path, multicomponent, output_path)
174
+
175
+
176
+ if __name__ == "__main__":
177
+ parser = ArgumentParser()
178
+ parser = FingerprintSubcommand.add_args(parser)
179
+
180
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
181
+ args = parser.parse_args()
182
+ args = FingerprintSubcommand.func(args)
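
A minimal sketch of driving the subcommand above programmatically, mirroring its ``__main__`` block (the CSV and checkpoint file names here are hypothetical):

    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser = FingerprintSubcommand.add_args(parser)
    # -i/--test-path and --model-paths are required; --ffn-block-index -1 takes
    # the encoding just before the final FFN layer
    args = parser.parse_args(
        ["-i", "smiles.csv", "--model-paths", "model_0.pt", "--ffn-block-index", "-1"]
    )
    FingerprintSubcommand.func(args)  # by default writes smiles_fps_0.csv next to the input
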
chemprop-updated/chemprop/cli/hpopt.py ADDED
@@ -0,0 +1,537 @@
1
+ from copy import deepcopy
2
+ import logging
3
+ from pathlib import Path
4
+ import shutil
5
+ import sys
6
+
7
+ from configargparse import ArgumentParser, Namespace
8
+ from lightning import pytorch as pl
9
+ from lightning.pytorch.callbacks import EarlyStopping
10
+ import numpy as np
11
+ import torch
12
+
13
+ from chemprop.cli.common import add_common_args, process_common_args, validate_common_args
14
+ from chemprop.cli.train import (
15
+ TrainSubcommand,
16
+ add_train_args,
17
+ build_datasets,
18
+ build_model,
19
+ build_splits,
20
+ normalize_inputs,
21
+ process_train_args,
22
+ save_config,
23
+ validate_train_args,
24
+ )
25
+ from chemprop.cli.utils.command import Subcommand
26
+ from chemprop.data import build_dataloader
27
+ from chemprop.nn import AggregationRegistry, MetricRegistry
28
+ from chemprop.nn.transforms import UnscaleTransform
29
+ from chemprop.nn.utils import Activation
30
+
31
+ NO_RAY = False
32
+ DEFAULT_SEARCH_SPACE = {
33
+ "activation": None,
34
+ "aggregation": None,
35
+ "aggregation_norm": None,
36
+ "batch_size": None,
37
+ "depth": None,
38
+ "dropout": None,
39
+ "ffn_hidden_dim": None,
40
+ "ffn_num_layers": None,
41
+ "final_lr_ratio": None,
42
+ "message_hidden_dim": None,
43
+ "init_lr_ratio": None,
44
+ "max_lr": None,
45
+ "warmup_epochs": None,
46
+ }
47
+
48
+ try:
49
+ import ray
50
+ from ray import tune
51
+ from ray.train import CheckpointConfig, RunConfig, ScalingConfig
52
+ from ray.train.lightning import (
53
+ RayDDPStrategy,
54
+ RayLightningEnvironment,
55
+ RayTrainReportCallback,
56
+ prepare_trainer,
57
+ )
58
+ from ray.train.torch import TorchTrainer
59
+ from ray.tune.schedulers import ASHAScheduler, FIFOScheduler
60
+
61
+ DEFAULT_SEARCH_SPACE = {
62
+ "activation": tune.choice(categories=list(Activation.keys())),
63
+ "aggregation": tune.choice(categories=list(AggregationRegistry.keys())),
64
+ "aggregation_norm": tune.quniform(lower=1, upper=200, q=1),
65
+ "batch_size": tune.choice([16, 32, 64, 128, 256]),
66
+ "depth": tune.qrandint(lower=2, upper=6, q=1),
67
+ "dropout": tune.choice([0.0] * 8 + list(np.arange(0.05, 0.45, 0.05))),
68
+ "ffn_hidden_dim": tune.qrandint(lower=300, upper=2400, q=100),
69
+ "ffn_num_layers": tune.qrandint(lower=1, upper=3, q=1),
70
+ "final_lr_ratio": tune.loguniform(lower=1e-2, upper=1),
71
+ "message_hidden_dim": tune.qrandint(lower=300, upper=2400, q=100),
72
+ "init_lr_ratio": tune.loguniform(lower=1e-2, upper=1),
73
+ "max_lr": tune.loguniform(lower=1e-4, upper=1e-2),
74
+ "warmup_epochs": None,
75
+ }
76
+ except ImportError:
77
+ NO_RAY = True
78
+
79
+ NO_HYPEROPT = False
80
+ try:
81
+ from ray.tune.search.hyperopt import HyperOptSearch
82
+ except ImportError:
83
+ NO_HYPEROPT = True
84
+
85
+ NO_OPTUNA = False
86
+ try:
87
+ from ray.tune.search.optuna import OptunaSearch
88
+ except ImportError:
89
+ NO_OPTUNA = True
90
+
91
+
92
+ logger = logging.getLogger(__name__)
93
+
94
+ SEARCH_SPACE = DEFAULT_SEARCH_SPACE
95
+
96
+ SEARCH_PARAM_KEYWORDS_MAP = {
97
+ "basic": ["depth", "ffn_num_layers", "dropout", "ffn_hidden_dim", "message_hidden_dim"],
98
+ "learning_rate": ["max_lr", "init_lr_ratio", "final_lr_ratio", "warmup_epochs"],
99
+ "all": list(DEFAULT_SEARCH_SPACE.keys()),
100
+ "init_lr": ["init_lr_ratio"],
101
+ "final_lr": ["final_lr_ratio"],
102
+ }
103
+
104
+
105
+ class HpoptSubcommand(Subcommand):
106
+ COMMAND = "hpopt"
107
+ HELP = "Perform hyperparameter optimization on the given task."
108
+
109
+ @classmethod
110
+ def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
111
+ parser = add_common_args(parser)
112
+ parser = add_train_args(parser)
113
+ return add_hpopt_args(parser)
114
+
115
+ @classmethod
116
+ def func(cls, args: Namespace):
117
+ args = process_common_args(args)
118
+ args = process_train_args(args)
119
+ args = process_hpopt_args(args)
120
+ validate_common_args(args)
121
+ validate_train_args(args)
122
+ main(args)
123
+
124
+
125
+ def add_hpopt_args(parser: ArgumentParser) -> ArgumentParser:
126
+ hpopt_args = parser.add_argument_group("Chemprop hyperparameter optimization arguments")
127
+
128
+ hpopt_args.add_argument(
129
+ "--search-parameter-keywords",
130
+ type=str,
131
+ nargs="+",
132
+ default=["basic"],
133
+ help=f"""The model parameters over which to search for an optimal hyperparameter configuration. Some options are bundles of parameters or otherwise special parameter operations. Special keywords include:
134
+ - ``basic``: Default set of hyperparameters for search (depth, ffn_num_layers, dropout, message_hidden_dim, and ffn_hidden_dim)
135
+ - ``learning_rate``: Search for max_lr, init_lr_ratio, final_lr_ratio, and warmup_epochs. The init_lr and final_lr values are searched as fractions of the max_lr value, and warmup_epochs is searched as a fraction of the total number of epochs used.
136
+ - ``all``: Include search for all 13 individual keyword options (i.e., also activation, aggregation, aggregation_norm, and batch_size, which aren't covered by the other two keywords).
137
+ Individual supported parameters:
138
+ {list(DEFAULT_SEARCH_SPACE.keys())}
139
+ """,
140
+ )
141
+
142
+ hpopt_args.add_argument(
143
+ "--hpopt-save-dir",
144
+ type=Path,
145
+ help="Directory to save the hyperparameter optimization results",
146
+ )
147
+
148
+ raytune_args = parser.add_argument_group("Ray Tune arguments")
149
+
150
+ raytune_args.add_argument(
151
+ "--raytune-num-samples",
152
+ type=int,
153
+ default=10,
154
+ help="Passed directly to Ray Tune ``TuneConfig`` to control number of trials to run",
155
+ )
156
+
157
+ raytune_args.add_argument(
158
+ "--raytune-search-algorithm",
159
+ choices=["random", "hyperopt", "optuna"],
160
+ default="hyperopt",
161
+ help="Passed to Ray Tune ``TuneConfig`` to control search algorithm",
162
+ )
163
+
164
+ raytune_args.add_argument(
165
+ "--raytune-trial-scheduler",
166
+ choices=["FIFO", "AsyncHyperBand"],
167
+ default="FIFO",
168
+ help="Passed to Ray Tune ``TuneConfig`` to control trial scheduler",
169
+ )
170
+
171
+ raytune_args.add_argument(
172
+ "--raytune-num-workers",
173
+ type=int,
174
+ default=1,
175
+ help="Passed directly to Ray Tune ``ScalingConfig`` to control number of workers to use",
176
+ )
177
+
178
+ raytune_args.add_argument(
179
+ "--raytune-use-gpu",
180
+ action="store_true",
181
+ help="Passed directly to Ray Tune ``ScalingConfig`` to control whether to use GPUs",
182
+ )
183
+
184
+ raytune_args.add_argument(
185
+ "--raytune-num-checkpoints-to-keep",
186
+ type=int,
187
+ default=1,
188
+ help="Passed directly to Ray Tune ``CheckpointConfig`` to control number of checkpoints to keep",
189
+ )
190
+
191
+ raytune_args.add_argument(
192
+ "--raytune-grace-period",
193
+ type=int,
194
+ default=10,
195
+ help="Passed directly to Ray Tune ``ASHAScheduler`` to control grace period",
196
+ )
197
+
198
+ raytune_args.add_argument(
199
+ "--raytune-reduction-factor",
200
+ type=int,
201
+ default=2,
202
+ help="Passed directly to Ray Tune ``ASHAScheduler`` to control reduction factor",
203
+ )
204
+
205
+ raytune_args.add_argument(
206
+ "--raytune-temp-dir", help="Passed directly to Ray Tune init to control temporary directory"
207
+ )
208
+
209
+ raytune_args.add_argument(
210
+ "--raytune-num-cpus",
211
+ type=int,
212
+ help="Passed directly to Ray Tune init to control number of CPUs to use",
213
+ )
214
+
215
+ raytune_args.add_argument(
216
+ "--raytune-num-gpus",
217
+ type=int,
218
+ help="Passed directly to Ray Tune init to control number of GPUs to use",
219
+ )
220
+
221
+ raytune_args.add_argument(
222
+ "--raytune-max-concurrent-trials",
223
+ type=int,
224
+ help="Passed directly to Ray Tune ``TuneConfig`` to control the maximum number of concurrent trials",
225
+ )
226
+
227
+ hyperopt_args = parser.add_argument_group("Hyperopt arguments")
228
+
229
+ hyperopt_args.add_argument(
230
+ "--hyperopt-n-initial-points",
231
+ type=int,
232
+ help="Passed directly to ``HyperOptSearch`` to control number of initial points to sample",
233
+ )
234
+
235
+ hyperopt_args.add_argument(
236
+ "--hyperopt-random-state-seed",
237
+ type=int,
238
+ default=None,
239
+ help="Passed directly to ``HyperOptSearch`` to control random state seed",
240
+ )
241
+
242
+ return parser
243
+
244
+
245
+ def process_hpopt_args(args: Namespace) -> Namespace:
246
+ if args.hpopt_save_dir is None:
247
+ args.hpopt_save_dir = Path(f"chemprop_hpopt/{args.data_path.stem}")
248
+
249
+ args.hpopt_save_dir.mkdir(exist_ok=True, parents=True)
250
+
251
+ search_parameters = set()
252
+
253
+ available_search_parameters = list(SEARCH_SPACE.keys()) + list(SEARCH_PARAM_KEYWORDS_MAP.keys())
254
+
255
+ for keyword in args.search_parameter_keywords:
256
+ if keyword not in available_search_parameters:
257
+ raise ValueError(
258
+ f"Search parameter keyword: {keyword} not in available options: {available_search_parameters}."
259
+ )
260
+
261
+ search_parameters.update(
262
+ SEARCH_PARAM_KEYWORDS_MAP[keyword]
263
+ if keyword in SEARCH_PARAM_KEYWORDS_MAP
264
+ else [keyword]
265
+ )
266
+
267
+ args.search_parameter_keywords = list(search_parameters)
268
+
269
+ if not args.hyperopt_n_initial_points:
270
+ args.hyperopt_n_initial_points = args.raytune_num_samples // 2
271
+
272
+ return args
273
+
274
+
275
+ def build_search_space(search_parameters: list[str], train_epochs: int) -> dict:
276
+ if "warmup_epochs" in search_parameters and SEARCH_SPACE.get("warmup_epochs", None) is None:
277
+ assert (
278
+ train_epochs >= 6
279
+ ), "Training epochs must be at least 6 to perform hyperparameter optimization for warmup_epochs."
280
+ SEARCH_SPACE["warmup_epochs"] = tune.qrandint(lower=1, upper=train_epochs // 2, q=1)
281
+
282
+ return {param: SEARCH_SPACE[param] for param in search_parameters}
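+ # e.g., build_search_space(["depth", "dropout"], 50) keeps just the tune.qrandint
+ # and tune.choice samplers registered for those two keys in SEARCH_SPACE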
283
+
284
+
285
+ def update_args_with_config(args: Namespace, config: dict) -> Namespace:
286
+ args = deepcopy(args)
287
+
288
+ for key, value in config.items():
289
+ match key:
290
+ case "final_lr_ratio":
291
+ setattr(args, "final_lr", value * config.get("max_lr", args.max_lr))
292
+
293
+ case "init_lr_ratio":
294
+ setattr(args, "init_lr", value * config.get("max_lr", args.max_lr))
295
+
296
+ case _:
297
+ assert key in args, f"Key: {key} not found in args."
298
+ setattr(args, key, value)
299
+
300
+ return args
301
+
302
+
303
+ def train_model(config, args, train_dset, val_dset, logger, output_transform, input_transforms):
304
+ args = update_args_with_config(args, config)
305
+
306
+ train_loader = build_dataloader(
307
+ train_dset, args.batch_size, args.num_workers, seed=args.data_seed
308
+ )
309
+ val_loader = build_dataloader(val_dset, args.batch_size, args.num_workers, shuffle=False)
310
+
311
+ seed = args.pytorch_seed if args.pytorch_seed is not None else torch.seed()
312
+
313
+ torch.manual_seed(seed)
314
+
315
+ model = build_model(args, train_loader.dataset, output_transform, input_transforms)
316
+ logger.info(model)
317
+
318
+ if args.tracking_metric == "val_loss":
319
+ T_tracking_metric = model.criterion.__class__
320
+ else:
321
+ T_tracking_metric = MetricRegistry[args.tracking_metric]
322
+ args.tracking_metric = "val/" + args.tracking_metric
323
+
324
+ monitor_mode = "max" if T_tracking_metric.higher_is_better else "min"
325
+ logger.debug(f"Evaluation metric: '{T_tracking_metric.alias}', mode: '{monitor_mode}'")
326
+
327
+ patience = args.patience if args.patience is not None else args.epochs
328
+ early_stopping = EarlyStopping(args.tracking_metric, patience=patience, mode=monitor_mode)
329
+
330
+ trainer = pl.Trainer(
331
+ accelerator=args.accelerator,
332
+ devices=args.devices,
333
+ max_epochs=args.epochs,
334
+ gradient_clip_val=args.grad_clip,
335
+ strategy=RayDDPStrategy(),
336
+ callbacks=[RayTrainReportCallback(), early_stopping],
337
+ plugins=[RayLightningEnvironment()],
338
+ deterministic=args.pytorch_seed is not None,
339
+ )
340
+ trainer = prepare_trainer(trainer)
341
+ trainer.fit(model, train_loader, val_loader)
342
+
343
+
344
+ def tune_model(
345
+ args, train_dset, val_dset, logger, monitor_mode, output_transform, input_transforms
346
+ ):
347
+ match args.raytune_trial_scheduler:
348
+ case "FIFO":
349
+ scheduler = FIFOScheduler()
350
+ case "AsyncHyperBand":
351
+ scheduler = ASHAScheduler(
352
+ max_t=args.epochs,
353
+ grace_period=min(args.raytune_grace_period, args.epochs),
354
+ reduction_factor=args.raytune_reduction_factor,
355
+ )
356
+ case _:
357
+ raise ValueError(f"Invalid trial scheduler! Got: {args.raytune_trial_scheduler}.")
358
+
359
+ resources_per_worker = {}
360
+ if args.raytune_num_cpus and args.raytune_max_concurrent_trials:
361
+ resources_per_worker["CPU"] = args.raytune_num_cpus / args.raytune_max_concurrent_trials
362
+ if args.raytune_num_gpus and args.raytune_max_concurrent_trials:
363
+ resources_per_worker["GPU"] = args.raytune_num_gpus / args.raytune_max_concurrent_trials
364
+ if not resources_per_worker:
365
+ resources_per_worker = None
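+ # e.g., --raytune-num-cpus 8 with --raytune-max-concurrent-trials 4 gives each worker 2 CPUs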
366
+
367
+ if args.raytune_num_gpus:
368
+ use_gpu = True
369
+ else:
370
+ use_gpu = args.raytune_use_gpu
371
+
372
+ scaling_config = ScalingConfig(
373
+ num_workers=args.raytune_num_workers,
374
+ use_gpu=use_gpu,
375
+ resources_per_worker=resources_per_worker,
376
+ trainer_resources={"CPU": 0},
377
+ )
378
+
379
+ checkpoint_config = CheckpointConfig(
380
+ num_to_keep=args.raytune_num_checkpoints_to_keep,
381
+ checkpoint_score_attribute=args.tracking_metric,
382
+ checkpoint_score_order=monitor_mode,
383
+ )
384
+
385
+ run_config = RunConfig(
386
+ checkpoint_config=checkpoint_config,
387
+ storage_path=args.hpopt_save_dir.absolute() / "ray_results",
388
+ )
389
+
390
+ ray_trainer = TorchTrainer(
391
+ lambda config: train_model(
392
+ config, args, train_dset, val_dset, logger, output_transform, input_transforms
393
+ ),
394
+ scaling_config=scaling_config,
395
+ run_config=run_config,
396
+ )
397
+
398
+ match args.raytune_search_algorithm:
399
+ case "random":
400
+ search_alg = None
401
+ case "hyperopt":
402
+ if NO_HYPEROPT:
403
+ raise ImportError(
404
+ "HyperOptSearch requires hyperopt to be installed. Use 'pip install -U hyperopt' to install it, or, if you installed from source, use 'pip install -e .[hpopt]' in the chemprop folder to install all hpopt-relevant packages."
405
+ )
406
+
407
+ search_alg = HyperOptSearch(
408
+ n_initial_points=args.hyperopt_n_initial_points,
409
+ random_state_seed=args.hyperopt_random_state_seed,
410
+ )
411
+ case "optuna":
412
+ if NO_OPTUNA:
413
+ raise ImportError(
414
+ "OptunaSearch requires optuna to be installed. Use 'pip install -U optuna' to install it, or, if you installed from source, use 'pip install -e .[hpopt]' in the chemprop folder to install all hpopt-relevant packages."
415
+ )
416
+
417
+ search_alg = OptunaSearch()
418
+
419
+ tune_config = tune.TuneConfig(
420
+ metric=args.tracking_metric,
421
+ mode=monitor_mode,
422
+ num_samples=args.raytune_num_samples,
423
+ scheduler=scheduler,
424
+ search_alg=search_alg,
425
+ trial_dirname_creator=lambda trial: str(trial.trial_id),
426
+ )
427
+
428
+ tuner = tune.Tuner(
429
+ ray_trainer,
430
+ param_space={
431
+ "train_loop_config": build_search_space(args.search_parameter_keywords, args.epochs)
432
+ },
433
+ tune_config=tune_config,
434
+ )
435
+
436
+ return tuner.fit()
437
+
438
+
439
+ def main(args: Namespace):
440
+ if NO_RAY:
441
+ raise ImportError(
442
+ "Ray Tune requires ray to be installed. If you installed Chemprop from PyPI, run 'pip install -U ray[tune]' to install ray. If you installed from source, use 'pip install -e .[hpopt]' in the Chemprop folder to install all hpopt-relevant packages."
443
+ )
444
+
445
+ if not ray.is_initialized():
446
+ try:
447
+ ray.init(
448
+ _temp_dir=args.raytune_temp_dir,
449
+ num_cpus=args.raytune_num_cpus,
450
+ num_gpus=args.raytune_num_gpus,
451
+ )
452
+ except OSError as e:
453
+ if "AF_UNIX path length cannot exceed 107 bytes" in str(e):
454
+ raise OSError(
455
+ f"Ray Tune fails due to: {e}. This can sometimes be solved by providing a temporary directory, num_cpus, and num_gpus to Ray Tune via the CLI: --raytune-temp-dir <absolute_path> --raytune-num-cpus <int> --raytune-num-gpus <int>."
456
+ )
457
+ else:
458
+ raise e
459
+ else:
460
+ logger.info("Ray is already initialized.")
461
+
462
+ format_kwargs = dict(
463
+ no_header_row=args.no_header_row,
464
+ smiles_cols=args.smiles_columns,
465
+ rxn_cols=args.reaction_columns,
466
+ target_cols=args.target_columns,
467
+ ignore_cols=args.ignore_columns,
468
+ splits_col=args.splits_column,
469
+ weight_col=args.weight_column,
470
+ bounded=args.loss_function is not None and "bounded" in args.loss_function,
471
+ )
472
+
473
+ featurization_kwargs = dict(
474
+ molecule_featurizers=args.molecule_featurizers, keep_h=args.keep_h, add_h=args.add_h
475
+ )
476
+
477
+ train_data, val_data, test_data = build_splits(args, format_kwargs, featurization_kwargs)
478
+ train_dset, val_dset, test_dset = build_datasets(args, train_data[0], val_data[0], test_data[0])
479
+
480
+ input_transforms = normalize_inputs(train_dset, val_dset, args)
481
+
482
+ if "regression" in args.task_type:
483
+ output_scaler = train_dset.normalize_targets()
484
+ val_dset.normalize_targets(output_scaler)
485
+ logger.info(f"Train data: mean = {output_scaler.mean_} | std = {output_scaler.scale_}")
486
+ output_transform = UnscaleTransform.from_standard_scaler(output_scaler)
487
+ else:
488
+ output_transform = None
489
+
490
+ train_loader = build_dataloader(
491
+ train_dset, args.batch_size, args.num_workers, seed=args.data_seed
492
+ )
493
+
494
+ model = build_model(args, train_loader.dataset, output_transform, input_transforms)
495
+ monitor_mode = "max" if model.metrics[0].higher_is_better else "min"
496
+
497
+ results = tune_model(
498
+ args, train_dset, val_dset, logger, monitor_mode, output_transform, input_transforms
499
+ )
500
+
501
+ best_result = results.get_best_result()
502
+ best_config = best_result.config["train_loop_config"]
503
+ best_checkpoint_path = Path(best_result.checkpoint.path) / "checkpoint.ckpt"
504
+
505
+ best_config_save_path = args.hpopt_save_dir / "best_config.toml"
506
+ best_checkpoint_save_path = args.hpopt_save_dir / "best_checkpoint.ckpt"
507
+ all_progress_save_path = args.hpopt_save_dir / "all_progress.csv"
508
+
509
+ logger.info(f"Best hyperparameters saved to: '{best_config_save_path}'")
510
+
511
+ args = update_args_with_config(args, best_config)
512
+
513
+ args = TrainSubcommand.parser.parse_known_args(namespace=args)[0]
514
+ save_config(TrainSubcommand.parser, args, best_config_save_path)
515
+
516
+ logger.info(
517
+ f"Best hyperparameter configuration checkpoint saved to '{best_checkpoint_save_path}'"
518
+ )
519
+
520
+ shutil.copyfile(best_checkpoint_path, best_checkpoint_save_path)
521
+
522
+ logger.info(f"Hyperparameter optimization results saved to '{all_progress_save_path}'")
523
+
524
+ result_df = results.get_dataframe()
525
+
526
+ result_df.to_csv(all_progress_save_path, index=False)
527
+
528
+ ray.shutdown()
529
+
530
+
531
+ if __name__ == "__main__":
532
+ parser = ArgumentParser()
533
+ parser = HpoptSubcommand.add_args(parser)
534
+
535
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
536
+ args = parser.parse_args()
537
+ HpoptSubcommand.func(args)
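
A hedged sketch of how a sampled trial config flows back into the CLI namespace through ``update_args_with_config`` above (``args`` is a parsed train namespace, and the numbers are purely illustrative):

    config = {"max_lr": 1e-3, "init_lr_ratio": 0.1, "final_lr_ratio": 0.01, "depth": 4}
    args = update_args_with_config(args, config)
    # args.init_lr  -> 1e-4 (0.1  * max_lr)
    # args.final_lr -> 1e-5 (0.01 * max_lr)
    # args.max_lr and args.depth are set verbatim; unknown keys raise an AssertionError
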
chemprop-updated/chemprop/cli/main.py ADDED
@@ -0,0 +1,85 @@
1
+ import logging
2
+ from pathlib import Path
3
+ import sys
4
+
5
+ from configargparse import ArgumentParser
6
+
7
+ from chemprop.cli.conf import LOG_DIR, LOG_LEVELS, NOW
8
+ from chemprop.cli.convert import ConvertSubcommand
9
+ from chemprop.cli.fingerprint import FingerprintSubcommand
10
+ from chemprop.cli.hpopt import HpoptSubcommand
11
+ from chemprop.cli.predict import PredictSubcommand
12
+ from chemprop.cli.train import TrainSubcommand
13
+ from chemprop.cli.utils import pop_attr
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+ SUBCOMMANDS = [
18
+ TrainSubcommand,
19
+ PredictSubcommand,
20
+ ConvertSubcommand,
21
+ FingerprintSubcommand,
22
+ HpoptSubcommand,
23
+ ]
24
+
25
+
26
+ def construct_parser():
27
+ parser = ArgumentParser()
28
+ subparsers = parser.add_subparsers(title="mode", dest="mode", required=True)
29
+
30
+ parent = ArgumentParser(add_help=False)
31
+ parent.add_argument(
32
+ "--logfile",
33
+ "--log",
34
+ nargs="?",
35
+ const="default",
36
+ help=f"Path to which the log file should be written (specifying just the flag alone will automatically log to a file ``{LOG_DIR}/MODE/TIMESTAMP.log``, where 'MODE' is the CLI mode chosen, e.g., ``{LOG_DIR}/MODE/{NOW}.log``)",
37
+ )
38
+ parent.add_argument("-v", action="store_true", help="Increase verbosity level to DEBUG")
39
+ parent.add_argument(
40
+ "-q",
41
+ action="count",
42
+ default=0,
43
+ help="Decrease verbosity level to WARNING or ERROR if specified twice",
44
+ )
45
+
46
+ parents = [parent]
47
+ for subcommand in SUBCOMMANDS:
48
+ subcommand.add(subparsers, parents)
49
+
50
+ return parser
51
+
52
+
53
+ def main():
54
+ parser = construct_parser()
55
+ args = parser.parse_args()
56
+ logfile, v_flag, q_count, mode, func = (
57
+ pop_attr(args, attr) for attr in ["logfile", "v", "q", "mode", "func"]
58
+ )
59
+
60
+ if v_flag and q_count:
61
+ parser.error("The -v and -q options cannot be used together.")
62
+
63
+ match logfile:
64
+ case None:
65
+ handler = logging.StreamHandler(sys.stderr)
66
+ case "default":
67
+ (LOG_DIR / mode).mkdir(parents=True, exist_ok=True)
68
+ handler = logging.FileHandler(str(LOG_DIR / mode / f"{NOW}.log"))
69
+ case _:
70
+ Path(logfile).parent.mkdir(parents=True, exist_ok=True)
71
+ handler = logging.FileHandler(logfile)
72
+
73
+ verbosity = q_count * -1 if q_count else (1 if v_flag else 0)
74
+ logging_level = LOG_LEVELS.get(verbosity, logging.ERROR)
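+ # per the -v/-q help text: -v -> DEBUG, -q -> WARNING, -qq -> ERROR (default is assumed to be INFO)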
75
+ logging.basicConfig(
76
+ handlers=[handler],
77
+ format="%(asctime)s - %(levelname)s:%(name)s - %(message)s",
78
+ level=logging_level,
79
+ datefmt="%Y-%m-%dT%H:%M:%S",
80
+ force=True,
81
+ )
82
+
83
+ logger.info(f"Running in mode '{mode}' with args: {vars(args)}")
84
+
85
+ func(args)
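
A few illustrative invocations of the entry point above, using only the flags it defines and assuming the usual ``chemprop`` console entry point (trailing arguments are sketched, not exhaustive):

    chemprop train -v ...                  # DEBUG-level logging to stderr
    chemprop predict -q --log run.log ...  # WARNING-level logging to run.log
    chemprop hpopt --logfile ...           # logs to LOG_DIR/hpopt/<timestamp>.log
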
chemprop-updated/chemprop/cli/predict.py ADDED
@@ -0,0 +1,444 @@
1
+ from argparse import ArgumentError, ArgumentParser, Namespace
2
+ import logging
3
+ from pathlib import Path
4
+ import sys
5
+ from typing import Iterator
6
+
7
+ from lightning import pytorch as pl
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+
12
+ from chemprop import data
13
+ from chemprop.cli.common import (
14
+ add_common_args,
15
+ find_models,
16
+ process_common_args,
17
+ validate_common_args,
18
+ )
19
+ from chemprop.cli.utils import LookupAction, Subcommand, build_data_from_files, make_dataset
20
+ from chemprop.models.utils import load_model, load_output_columns
21
+ from chemprop.nn.metrics import LossFunctionRegistry
22
+ from chemprop.nn.predictors import EvidentialFFN, MulticlassClassificationFFN, MveFFN
23
+ from chemprop.uncertainty import (
24
+ MVEWeightingCalibrator,
25
+ NoUncertaintyEstimator,
26
+ RegressionCalibrator,
27
+ RegressionEvaluator,
28
+ UncertaintyCalibratorRegistry,
29
+ UncertaintyEstimatorRegistry,
30
+ UncertaintyEvaluatorRegistry,
31
+ )
32
+ from chemprop.utils import Factory
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ class PredictSubcommand(Subcommand):
38
+ COMMAND = "predict"
39
+ HELP = "Use a pretrained chemprop model for prediction."
40
+
41
+ @classmethod
42
+ def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
43
+ parser = add_common_args(parser)
44
+ return add_predict_args(parser)
45
+
46
+ @classmethod
47
+ def func(cls, args: Namespace):
48
+ args = process_common_args(args)
49
+ validate_common_args(args)
50
+ args = process_predict_args(args)
51
+ main(args)
52
+
53
+
54
+ def add_predict_args(parser: ArgumentParser) -> ArgumentParser:
55
+ parser.add_argument(
56
+ "-i",
57
+ "--test-path",
58
+ required=True,
59
+ type=Path,
60
+ help="Path to an input CSV file containing SMILES",
61
+ )
62
+ parser.add_argument(
63
+ "-o",
64
+ "--output",
65
+ "--preds-path",
66
+ type=Path,
67
+ help="Specify path to which predictions will be saved. If the file extension is .pkl, it will be saved as a pickle file. Otherwise, chemprop will save predictions as a CSV. If multiple models are used to make predictions, the average predictions will be saved in the file, and another file ending in '_individual' with the same file extension will save the predictions for each individual model, with the column names being the target names appended with the model index (e.g., '_model_<index>').",
68
+ )
69
+ parser.add_argument(
70
+ "--drop-extra-columns",
71
+ action="store_true",
72
+ help="Whether to drop all columns from the test data file besides the SMILES columns and the new prediction columns",
73
+ )
74
+ parser.add_argument(
75
+ "--model-paths",
76
+ "--model-path",
77
+ required=True,
78
+ type=Path,
79
+ nargs="+",
80
+ help="Location of checkpoint(s) or model file(s) to use for prediction. It can be a path to either a single pretrained model checkpoint (.ckpt) or a single pretrained model file (.pt), a directory that contains these files, or a list of paths and/or directories. If a directory, chemprop will recursively search it and predict with all (.pt) models found.",
81
+ )
82
+
83
+ unc_args = parser.add_argument_group("Uncertainty and calibration args")
84
+ unc_args.add_argument(
85
+ "--cal-path", type=Path, help="Path to data file to be used for uncertainty calibration."
86
+ )
87
+ unc_args.add_argument(
88
+ "--uncertainty-method",
89
+ default="none",
90
+ action=LookupAction(UncertaintyEstimatorRegistry),
91
+ help="The method of calculating uncertainty.",
92
+ )
93
+ unc_args.add_argument(
94
+ "--calibration-method",
95
+ action=LookupAction(UncertaintyCalibratorRegistry),
96
+ help="The method used for calibrating the uncertainty calculated with uncertainty method.",
97
+ )
98
+ unc_args.add_argument(
99
+ "--evaluation-methods",
100
+ "--evaluation-method",
101
+ nargs="+",
102
+ action=LookupAction(UncertaintyEvaluatorRegistry),
103
+ help="The methods used for evaluating the uncertainty performance if the test data provided includes targets. Available methods are [nll, miscalibration_area, ence, spearman] or any available classification or multiclass metric.",
104
+ )
105
+ # unc_args.add_argument(
106
+ # "--evaluation-scores-path", help="Location to save the results of uncertainty evaluations."
107
+ # )
108
+ unc_args.add_argument(
109
+ "--uncertainty-dropout-p",
110
+ type=float,
111
+ default=0.1,
112
+ help="The probability to use for Monte Carlo dropout uncertainty estimation.",
113
+ )
114
+ unc_args.add_argument(
115
+ "--dropout-sampling-size",
116
+ type=int,
117
+ default=10,
118
+ help="The number of samples to use for Monte Carlo dropout uncertainty estimation. Distinct from the dropout used during training.",
119
+ )
120
+ unc_args.add_argument(
121
+ "--calibration-interval-percentile",
122
+ type=float,
123
+ default=95,
124
+ help="Sets the percentile used in the calibration methods. Must be in the range (1, 100).",
125
+ )
126
+ unc_args.add_argument(
127
+ "--conformal-alpha",
128
+ type=float,
129
+ default=0.1,
130
+ help="Target error rate for conformal prediction. Must be in the range (0, 1).",
131
+ )
132
+ # TODO: Decide if we want to implement this in v2.1.x
133
+ # unc_args.add_argument(
134
+ # "--regression-calibrator-metric",
135
+ # choices=["stdev", "interval"],
136
+ # help="Regression calibrators can output either a stdev or an interval.",
137
+ # )
138
+ unc_args.add_argument(
139
+ "--cal-descriptors-path",
140
+ nargs="+",
141
+ action="append",
142
+ help="Path to extra descriptors to concatenate to learned representation in calibration dataset.",
143
+ )
144
+ # TODO: Add in v2.1.x
145
+ # unc_args.add_argument(
146
+ # "--calibration-phase-features-path",
147
+ # help=" ",
148
+ # )
149
+ unc_args.add_argument(
150
+ "--cal-atom-features-path",
151
+ nargs="+",
152
+ action="append",
153
+ help="Path to the extra atom features in calibration dataset.",
154
+ )
155
+ unc_args.add_argument(
156
+ "--cal-atom-descriptors-path",
157
+ nargs="+",
158
+ action="append",
159
+ help="Path to the extra atom descriptors in calibration dataset.",
160
+ )
161
+ unc_args.add_argument(
162
+ "--cal-bond-features-path",
163
+ nargs="+",
164
+ action="append",
165
+ help="Path to the extra bond features in the calibration dataset.",
166
+ )
167
+
168
+ return parser
169
+
170
+
171
+ def process_predict_args(args: Namespace) -> Namespace:
172
+ if args.test_path.suffix not in [".csv"]:
173
+ raise ArgumentError(
174
+ argument=None, message=f"Input data must be a CSV file. Got {args.test_path}"
175
+ )
176
+ if args.output is None:
177
+ args.output = args.test_path.parent / (args.test_path.stem + "_preds.csv")
178
+ if args.output.suffix not in [".csv", ".pkl"]:
179
+ raise ArgumentError(
180
+ argument=None, message=f"Output must be a CSV or Pickle file. Got {args.output}"
181
+ )
182
+ return args
183
+
184
+
185
+ def prepare_data_loader(
186
+ args: Namespace, multicomponent: bool, is_calibration: bool, format_kwargs: dict
187
+ ):
188
+ data_path = args.cal_path if is_calibration else args.test_path
189
+ descriptors_path = args.cal_descriptors_path if is_calibration else args.descriptors_path
190
+ atom_feats_path = args.cal_atom_features_path if is_calibration else args.atom_features_path
191
+ bond_feats_path = args.cal_bond_features_path if is_calibration else args.bond_features_path
192
+ atom_descs_path = (
193
+ args.cal_atom_descriptors_path if is_calibration else args.atom_descriptors_path
194
+ )
195
+
196
+ featurization_kwargs = dict(
197
+ molecule_featurizers=args.molecule_featurizers, keep_h=args.keep_h, add_h=args.add_h
198
+ )
199
+
200
+ datas = build_data_from_files(
201
+ data_path,
202
+ **format_kwargs,
203
+ p_descriptors=descriptors_path,
204
+ p_atom_feats=atom_feats_path,
205
+ p_bond_feats=bond_feats_path,
206
+ p_atom_descs=atom_descs_path,
207
+ **featurization_kwargs,
208
+ )
209
+
210
+ dsets = [make_dataset(d, args.rxn_mode, args.multi_hot_atom_featurizer_mode) for d in datas]
211
+ dset = data.MulticomponentDataset(dsets) if multicomponent else dsets[0]
212
+
213
+ return data.build_dataloader(dset, args.batch_size, args.num_workers, shuffle=False)
214
+
215
+
216
+ def make_prediction_for_models(
217
+ args: Namespace, model_paths: list[Path], multicomponent: bool, output_path: Path
218
+ ):
219
+ model = load_model(model_paths[0], multicomponent)
220
+ output_columns = load_output_columns(model_paths[0])
221
+ bounded = any(
222
+ isinstance(model.criterion, LossFunctionRegistry[loss_function])
223
+ for loss_function in LossFunctionRegistry.keys()
224
+ if "bounded" in loss_function
225
+ )
226
+ format_kwargs = dict(
227
+ no_header_row=args.no_header_row,
228
+ smiles_cols=args.smiles_columns,
229
+ rxn_cols=args.reaction_columns,
230
+ ignore_cols=None,
231
+ splits_col=None,
232
+ weight_col=None,
233
+ bounded=bounded,
234
+ )
235
+ format_kwargs["target_cols"] = output_columns if args.evaluation_methods is not None else []
236
+ test_loader = prepare_data_loader(args, multicomponent, False, format_kwargs)
237
+ logger.info(f"test size: {len(test_loader.dataset)}")
238
+ if args.cal_path is not None:
239
+ format_kwargs["target_cols"] = output_columns
240
+ cal_loader = prepare_data_loader(args, multicomponent, True, format_kwargs)
241
+ logger.info(f"calibration size: {len(cal_loader.dataset)}")
242
+
243
+ uncertainty_estimator = Factory.build(
244
+ UncertaintyEstimatorRegistry[args.uncertainty_method],
245
+ ensemble_size=args.dropout_sampling_size,
246
+ dropout=args.uncertainty_dropout_p,
247
+ )
248
+
249
+ models = [load_model(model_path, multicomponent) for model_path in model_paths]
250
+ trainer = pl.Trainer(
251
+ logger=False, enable_progress_bar=True, accelerator=args.accelerator, devices=args.devices
252
+ )
253
+ test_individual_preds, test_individual_uncs = uncertainty_estimator(
254
+ test_loader, models, trainer
255
+ )
256
+ test_preds = torch.mean(test_individual_preds, dim=0)
257
+ if not isinstance(uncertainty_estimator, NoUncertaintyEstimator):
258
+ test_uncs = torch.mean(test_individual_uncs, dim=0)
259
+ else:
260
+ test_uncs = None
261
+
262
+ if args.calibration_method is not None:
263
+ uncertainty_calibrator = Factory.build(
264
+ UncertaintyCalibratorRegistry[args.calibration_method],
265
+ p=args.calibration_interval_percentile / 100,
266
+ alpha=args.conformal_alpha,
267
+ )
268
+ cal_targets = cal_loader.dataset.Y
269
+ cal_mask = torch.from_numpy(np.isfinite(cal_targets))
270
+ cal_targets = np.nan_to_num(cal_targets, nan=0.0)
271
+ cal_targets = torch.from_numpy(cal_targets)
272
+ cal_individual_preds, cal_individual_uncs = uncertainty_estimator(
273
+ cal_loader, models, trainer
274
+ )
275
+ cal_preds = torch.mean(cal_individual_preds, dim=0)
276
+ cal_uncs = torch.mean(cal_individual_uncs, dim=0)
277
+ if isinstance(uncertainty_calibrator, MVEWeightingCalibrator):
278
+ uncertainty_calibrator.fit(cal_preds, cal_individual_uncs, cal_targets, cal_mask)
279
+ test_uncs = uncertainty_calibrator.apply(cal_individual_uncs)
280
+ else:
281
+ if isinstance(uncertainty_calibrator, RegressionCalibrator):
282
+ uncertainty_calibrator.fit(cal_preds, cal_uncs, cal_targets, cal_mask)
283
+ else:
284
+ uncertainty_calibrator.fit(cal_uncs, cal_targets, cal_mask)
285
+ test_uncs = uncertainty_calibrator.apply(test_uncs)
286
+ for i in range(test_individual_uncs.shape[0]):
287
+ test_individual_uncs[i] = uncertainty_calibrator.apply(test_individual_uncs[i])
288
+
289
+ if args.evaluation_methods is not None:
290
+ uncertainty_evaluators = [
291
+ Factory.build(UncertaintyEvaluatorRegistry[method])
292
+ for method in args.evaluation_methods
293
+ ]
294
+ logger.info("Uncertainty evaluation metric:")
295
+ for evaluator in uncertainty_evaluators:
296
+ test_targets = test_loader.dataset.Y
297
+ test_mask = torch.from_numpy(np.isfinite(test_targets))
298
+ test_targets = np.nan_to_num(test_targets, nan=0.0)
299
+ test_targets = torch.from_numpy(test_targets)
300
+ if isinstance(evaluator, RegressionEvaluator):
301
+ metric_value = evaluator.evaluate(test_preds, test_uncs, test_targets, test_mask)
302
+ else:
303
+ metric_value = evaluator.evaluate(test_uncs, test_targets, test_mask)
304
+ logger.info(f"{evaluator.alias}: {metric_value.tolist()}")
305
+
306
+ if args.uncertainty_method == "none" and (
307
+ isinstance(model.predictor, MveFFN) or isinstance(model.predictor, EvidentialFFN)
308
+ ):
309
+ test_preds = test_preds[..., 0]
310
+ test_individual_preds = test_individual_preds[..., 0]
311
+
312
+ if output_columns is None:
313
+ output_columns = [
314
+ f"pred_{i}" for i in range(test_preds.shape[1])
315
+ ] # TODO: need to improve this for cases like multi-task MVE and multi-task multiclass
316
+
317
+ save_predictions(args, model, output_columns, test_preds, test_uncs, output_path)
318
+
319
+ if len(model_paths) > 1:
320
+ save_individual_predictions(
321
+ args,
322
+ model,
323
+ model_paths,
324
+ output_columns,
325
+ test_individual_preds,
326
+ test_individual_uncs,
327
+ output_path,
328
+ )
329
+
330
+
331
+ def save_predictions(args, model, output_columns, test_preds, test_uncs, output_path):
332
+ unc_columns = [f"{col}_unc" for col in output_columns]
333
+
334
+ if isinstance(model.predictor, MulticlassClassificationFFN):
335
+ output_columns = output_columns + [f"{col}_prob" for col in output_columns]
336
+ predicted_class_labels = test_preds.argmax(axis=-1)
337
+ formatted_probability_strings = np.apply_along_axis(
338
+ lambda x: ",".join(map(str, x)), 2, test_preds.numpy()
339
+ )
340
+ test_preds = np.concatenate(
341
+ (predicted_class_labels, formatted_probability_strings), axis=-1
342
+ )
343
+
344
+ df_test = pd.read_csv(
345
+ args.test_path, header=None if args.no_header_row else "infer", index_col=False
346
+ )
347
+ df_test[output_columns] = test_preds
348
+
349
+ if args.uncertainty_method not in ["none", "classification"]:
350
+ df_test[unc_columns] = np.round(test_uncs, 6)
351
+
352
+ if output_path.suffix == ".pkl":
353
+ df_test = df_test.reset_index(drop=True)
354
+ df_test.to_pickle(output_path)
355
+ else:
356
+ df_test.to_csv(output_path, index=False)
357
+ logger.info(f"Predictions saved to '{output_path}'")
358
+
359
+
360
+ def save_individual_predictions(
361
+ args,
362
+ model,
363
+ model_paths,
364
+ output_columns,
365
+ test_individual_preds,
366
+ test_individual_uncs,
367
+ output_path,
368
+ ):
369
+ unc_columns = [
370
+ f"{col}_unc_model_{i}" for i in range(len(model_paths)) for col in output_columns
371
+ ]
372
+
373
+ if isinstance(model.predictor, MulticlassClassificationFFN):
374
+ output_columns = [
375
+ item
376
+ for i in range(len(model_paths))
377
+ for col in output_columns
378
+ for item in (f"{col}_model_{i}", f"{col}_prob_model_{i}")
379
+ ]
380
+
381
+ predicted_class_labels = test_individual_preds.argmax(axis=-1)
382
+ formatted_probability_strings = np.apply_along_axis(
383
+ lambda x: ",".join(map(str, x)), 3, test_individual_preds.numpy()
384
+ )
385
+ test_individual_preds = np.concatenate(
386
+ (predicted_class_labels, formatted_probability_strings), axis=-1
387
+ )
388
+ else:
389
+ output_columns = [
390
+ f"{col}_model_{i}" for i in range(len(model_paths)) for col in output_columns
391
+ ]
392
+
393
+ m, n, t = test_individual_preds.shape
394
+ test_individual_preds = np.transpose(test_individual_preds, (1, 0, 2)).reshape(n, m * t)
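+ # (n_models, n_rows, n_tasks) -> (n_rows, n_models * n_tasks): one column per model/task pair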
395
+ df_test = pd.read_csv(
396
+ args.test_path, header=None if args.no_header_row else "infer", index_col=False
397
+ )
398
+ df_test[output_columns] = test_individual_preds
399
+
400
+ if args.uncertainty_method not in ["none", "classification", "ensemble"]:
401
+ m, n, t = test_individual_uncs.shape
402
+ test_individual_uncs = np.transpose(test_individual_uncs, (1, 0, 2)).reshape(n, m * t)
403
+ df_test[unc_columns] = np.round(test_individual_uncs, 6)
404
+
405
+ output_path = output_path.parent / Path(
406
+ str(args.output.stem) + "_individual" + str(output_path.suffix)
407
+ )
408
+ if output_path.suffix == ".pkl":
409
+ df_test = df_test.reset_index(drop=True)
410
+ df_test.to_pickle(output_path)
411
+ else:
412
+ df_test.to_csv(output_path, index=False)
413
+ logger.info(f"Individual predictions saved to '{output_path}'")
414
+ for i, model_path in enumerate(model_paths):
415
+ logger.info(
416
+ f"Results from model path {model_path} are saved under the column name ending with 'model_{i}'"
417
+ )
418
+
419
+
420
+ def main(args):
421
+ match (args.smiles_columns, args.reaction_columns):
422
+ case [None, None]:
423
+ n_components = 1
424
+ case [_, None]:
425
+ n_components = len(args.smiles_columns)
426
+ case [None, _]:
427
+ n_components = len(args.reaction_columns)
428
+ case _:
429
+ n_components = len(args.smiles_columns) + len(args.reaction_columns)
430
+
431
+ multicomponent = n_components > 1
432
+
433
+ model_paths = find_models(args.model_paths)
434
+
435
+ make_prediction_for_models(args, model_paths, multicomponent, output_path=args.output)
436
+
437
+
438
+ if __name__ == "__main__":
439
+ parser = ArgumentParser()
440
+ parser = PredictSubcommand.add_args(parser)
441
+
442
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
443
+ args = parser.parse_args()
444
+ args = PredictSubcommand.func(args)
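
A minimal sketch of an ensemble prediction with post-hoc calibration through the parser above (the paths are hypothetical, and ``zscaling`` is assumed to be a key registered in ``UncertaintyCalibratorRegistry``):

    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser = PredictSubcommand.add_args(parser)
    args = parser.parse_args(
        [
            "-i", "test.csv",
            "--model-paths", "ckpts/",  # a directory: every .pt model found is used
            "--uncertainty-method", "ensemble",
            "--cal-path", "cal.csv",
            "--calibration-method", "zscaling",
        ]
    )
    PredictSubcommand.func(args)  # writes test_preds.csv and test_preds_individual.csv
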
chemprop-updated/chemprop/cli/train.py ADDED
@@ -0,0 +1,1340 @@
1
+ from copy import deepcopy
2
+ from io import StringIO
3
+ import json
4
+ import logging
5
+ from pathlib import Path
6
+ import sys
7
+ from tempfile import TemporaryDirectory
8
+
9
+ from configargparse import ArgumentError, ArgumentParser, Namespace
10
+ from lightning import pytorch as pl
11
+ from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
12
+ from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
13
+ from lightning.pytorch.strategies import DDPStrategy
14
+ import numpy as np
15
+ import pandas as pd
16
+ from rich.console import Console
17
+ from rich.table import Column, Table
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from chemprop.cli.common import (
22
+ add_common_args,
23
+ find_models,
24
+ process_common_args,
25
+ validate_common_args,
26
+ )
27
+ from chemprop.cli.conf import CHEMPROP_TRAIN_DIR, NOW
28
+ from chemprop.cli.utils import (
29
+ LookupAction,
30
+ Subcommand,
31
+ build_data_from_files,
32
+ get_column_names,
33
+ make_dataset,
34
+ parse_indices,
35
+ )
36
+ from chemprop.cli.utils.args import uppercase
37
+ from chemprop.data import (
38
+ MoleculeDataset,
39
+ MolGraphDataset,
40
+ MulticomponentDataset,
41
+ ReactionDatapoint,
42
+ SplitType,
43
+ build_dataloader,
44
+ make_split_indices,
45
+ split_data_by_indices,
46
+ )
47
+ from chemprop.data.datasets import _MolGraphDatasetMixin
48
+ from chemprop.models import MPNN, MulticomponentMPNN, save_model
49
+ from chemprop.nn import AggregationRegistry, LossFunctionRegistry, MetricRegistry, PredictorRegistry
50
+ from chemprop.nn.message_passing import (
51
+ AtomMessagePassing,
52
+ BondMessagePassing,
53
+ MulticomponentMessagePassing,
54
+ )
55
+ from chemprop.nn.transforms import GraphTransform, ScaleTransform, UnscaleTransform
56
+ from chemprop.nn.utils import Activation
57
+ from chemprop.utils import Factory
58
+
59
+ logger = logging.getLogger(__name__)
60
+
61
+
62
+ _CV_REMOVAL_ERROR = (
63
+ "The -k/--num-folds argument was removed in v2.1.0 - use --num-replicates instead."
64
+ )
65
+
66
+
67
+ class TrainSubcommand(Subcommand):
68
+ COMMAND = "train"
69
+ HELP = "Train a chemprop model."
70
+ parser = None
71
+
72
+ @classmethod
73
+ def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
74
+ parser = add_common_args(parser)
75
+ parser = add_train_args(parser)
76
+ cls.parser = parser
77
+ return parser
78
+
79
+ @classmethod
80
+ def func(cls, args: Namespace):
81
+ args = process_common_args(args)
82
+ validate_common_args(args)
83
+ args = process_train_args(args)
84
+ validate_train_args(args)
85
+
86
+ args.output_dir.mkdir(exist_ok=True, parents=True)
87
+ config_path = args.output_dir / "config.toml"
88
+ save_config(cls.parser, args, config_path)
89
+ main(args)
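+ # the config.toml saved above can be passed back via --config-path to reproduce the run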
90
+
91
+
92
+ def add_train_args(parser: ArgumentParser) -> ArgumentParser:
93
+ parser.add_argument(
94
+ "--config-path",
95
+ type=Path,
96
+ is_config_file=True,
97
+ help="Path to a configuration file (command line arguments override values in the configuration file)",
98
+ )
99
+ parser.add_argument(
100
+ "-i",
101
+ "--data-path",
102
+ type=Path,
103
+ help="Path to an input CSV file containing SMILES and the associated target values",
104
+ )
105
+ parser.add_argument(
106
+ "-o",
107
+ "--output-dir",
108
+ "--save-dir",
109
+ type=Path,
110
+ help="Directory where training outputs will be saved (defaults to ``CURRENT_DIRECTORY/chemprop_training/STEM_OF_INPUT/TIME_STAMP``)",
111
+ )
112
+ parser.add_argument(
113
+ "--remove-checkpoints",
114
+ action="store_true",
115
+ help="Remove intermediate checkpoint files after training is complete.",
116
+ )
117
+
118
+ # TODO: Add in v2.1; see if we can tell lightning how often to log training loss
119
+ # parser.add_argument(
120
+ # "--log-frequency",
121
+ # type=int,
122
+ # default=10,
123
+ # help="The number of batches between each logging of the training loss.",
124
+ # )
125
+
126
+ transfer_args = parser.add_argument_group("transfer learning args")
127
+ transfer_args.add_argument(
128
+ "--checkpoint",
129
+ type=Path,
130
+ nargs="+",
131
+ help="Path to checkpoint(s) or model file(s) for loading and overwriting weights. Accepts a single pre-trained model checkpoint (.ckpt), a single model file (.pt), a directory containing such files, or a list of paths and directories. If a directory is provided, it will recursively search for and use all (.pt) files found.",
132
+ )
133
+ transfer_args.add_argument(
134
+ "--freeze-encoder",
135
+ action="store_true",
136
+ help="Freeze the message passing layer from the checkpoint model (specified by ``--checkpoint``).",
137
+ )
138
+ transfer_args.add_argument(
139
+ "--model-frzn",
140
+ help="Path to model checkpoint file to be loaded for overwriting and freezing weights. By default, all MPNN weights are frozen with this option.",
141
+ )
142
+ transfer_args.add_argument(
143
+ "--frzn-ffn-layers",
144
+ type=int,
145
+ default=0,
146
+ help="Freeze the first ``n`` layers of the FFN from the checkpoint model (specified by ``--checkpoint``). The message passing layer should also be frozen with ``--freeze-encoder``.",
147
+ )
148
+ # transfer_args.add_argument(
149
+ # "--freeze-first-only",
150
+ # action="store_true",
151
+ # help="Determines whether or not to use checkpoint_frzn for just the first encoder. Default (False) is to use the checkpoint to freeze all encoders. (only relevant for number_of_molecules > 1, where checkpoint model has number_of_molecules = 1)",
152
+ # )
153
+
154
+ # TODO: Add in v2.1
155
+ # parser.add_argument(
156
+ # "--resume-experiment",
157
+ # action="store_true",
158
+ # help="Whether to resume the experiment. Loads test results from any folds that have already been completed and skips training those folds.",
159
+ # )
160
+ # parser.add_argument(
161
+ # "--config-path",
162
+ # help="Path to a :code:`.json` file containing arguments. Any arguments present in the config file will override arguments specified via the command line or by the defaults.",
163
+ # )
164
+ parser.add_argument(
165
+ "--ensemble-size",
166
+ type=int,
167
+ default=1,
168
+ help="Number of models in ensemble for each splitting of data",
169
+ )
170
+
171
+ # TODO: Add in v2.2
172
+ # abt_args = parser.add_argument_group("atom/bond target args")
173
+ # abt_args.add_argument(
174
+ # "--is-atom-bond-targets",
175
+ # action="store_true",
176
+ # help="Whether this is atomic/bond properties prediction.",
177
+ # )
178
+ # abt_args.add_argument(
179
+ # "--no-adding-bond-types",
180
+ # action="store_true",
181
+ # help="Whether the bond types determined by RDKit molecules are added to the output of bond targets. This option is intended to be used with the :code:`is_atom_bond_targets`.",
182
+ # )
183
+ # abt_args.add_argument(
184
+ # "--keeping-atom-map",
185
+ # action="store_true",
186
+ # help="Whether RDKit molecules keep the original atom mapping. This option is intended to be used when providing atom-mapped SMILES with the :code:`is_atom_bond_targets`.",
187
+ # )
188
+ # abt_args.add_argument(
189
+ # "--no-shared-atom-bond-ffn",
190
+ # action="store_true",
191
+ # help="Whether the FFN weights for atom and bond targets should be independent between tasks.",
192
+ # )
193
+ # abt_args.add_argument(
194
+ # "--weights-ffn-num-layers",
195
+ # type=int,
196
+ # default=2,
197
+ # help="Number of layers in FFN for determining weights used in constrained targets.",
198
+ # )
199
+
200
+ mp_args = parser.add_argument_group("message passing")
201
+ mp_args.add_argument(
202
+ "--message-hidden-dim", type=int, default=300, help="Hidden dimension of the messages"
203
+ )
204
+ mp_args.add_argument(
205
+ "--message-bias", action="store_true", help="Add bias to the message passing layers"
206
+ )
207
+ mp_args.add_argument("--depth", type=int, default=3, help="Number of message passing steps")
208
+ mp_args.add_argument(
209
+ "--undirected",
210
+ action="store_true",
211
+ help="Pass messages on undirected bonds/edges (always sum the two relevant bond vectors)",
212
+ )
213
+ mp_args.add_argument(
214
+ "--dropout",
215
+ type=float,
216
+ default=0.0,
217
+ help="Dropout probability in message passing/FFN layers",
218
+ )
219
+ mp_args.add_argument(
220
+ "--mpn-shared",
221
+ action="store_true",
222
+ help="Whether to use the same message passing neural network for all input molecules (only relevant if ``number_of_molecules`` > 1)",
223
+ )
224
+ mp_args.add_argument(
225
+ "--activation",
226
+ type=uppercase,
227
+ default="RELU",
228
+ choices=list(Activation.keys()),
229
+ help="Activation function in message passing/FFN layers",
230
+ )
231
+ mp_args.add_argument(
232
+ "--aggregation",
233
+ "--agg",
234
+ default="norm",
235
+ action=LookupAction(AggregationRegistry),
236
+ help="Aggregation mode to use during graph prediction",
237
+ )
238
+ mp_args.add_argument(
239
+ "--aggregation-norm",
240
+ type=float,
241
+ default=100,
242
+ help="Normalization factor by which to divide summed up atomic features for ``norm`` aggregation",
243
+ )
244
+ mp_args.add_argument(
245
+ "--atom-messages", action="store_true", help="Pass messages on atoms rather than bonds."
246
+ )
247
+
248
+ # TODO: Add in v2.1
249
+ # mpsolv_args = parser.add_argument_group("message passing with solvent")
250
+ # mpsolv_args.add_argument(
251
+ # "--reaction-solvent",
252
+ # action="store_true",
253
+ # help="Whether to adjust the MPNN layer to take as input a reaction and a molecule, and to encode them with separate MPNNs.",
254
+ # )
255
+ # mpsolv_args.add_argument(
256
+ # "--bias-solvent",
257
+ # action="store_true",
258
+ # help="Whether to add bias to linear layers for solvent MPN if :code:`reaction_solvent` is True.",
259
+ # )
260
+ # mpsolv_args.add_argument(
261
+ # "--hidden-size-solvent",
262
+ # type=int,
263
+ # default=300,
264
+ # help="Dimensionality of hidden layers in solvent MPN if :code:`reaction_solvent` is True.",
265
+ # )
266
+ # mpsolv_args.add_argument(
267
+ # "--depth-solvent",
268
+ # type=int,
269
+ # default=3,
270
+ # help="Number of message passing steps for solvent if :code:`reaction_solvent` is True.",
271
+ # )
272
+
273
+ ffn_args = parser.add_argument_group("FFN args")
274
+ ffn_args.add_argument(
275
+ "--ffn-hidden-dim", type=int, default=300, help="Hidden dimension in the FFN top model"
276
+ )
277
+ ffn_args.add_argument( # TODO: the default in v1 was 2. (see weights_ffn_num_layers option) Do we really want the default to now be 1?
278
+ "--ffn-num-layers", type=int, default=1, help="Number of layers in FFN top model"
279
+ )
280
+ # TODO: Decide if we want to implement this in v2
281
+ # ffn_args.add_argument(
282
+ # "--features-only",
283
+ # action="store_true",
284
+ # help="Use only the additional features in an FFN, no graph network.",
285
+ # )
286
+
287
+ extra_mpnn_args = parser.add_argument_group("extra MPNN args")
288
+ extra_mpnn_args.add_argument(
289
+ "--batch-norm", action="store_true", help="Turn on batch normalization after aggregation"
290
+ )
291
+ extra_mpnn_args.add_argument(
292
+ "--multiclass-num-classes",
293
+ type=int,
294
+ default=3,
295
+ help="Number of classes when running multiclass classification",
296
+ )
297
+ # TODO: Add in v2.1
298
+ # extra_mpnn_args.add_argument(
299
+ # "--spectral-activation",
300
+ # default="exp",
301
+ # choices=["softplus", "exp"],
302
+ # help="Indicates which function to use in task_type spectra training to constrain outputs to be positive.",
303
+ # )
304
+
305
+ train_data_args = parser.add_argument_group("training input data args")
306
+ train_data_args.add_argument(
307
+ "-w",
308
+ "--weight-column",
309
+ help="Name of the column in the input CSV containing individual data weights",
310
+ )
311
+ train_data_args.add_argument(
312
+ "--target-columns",
313
+ nargs="+",
314
+ help="Name of the columns containing target values (by default, uses all columns except the SMILES column and the ``ignore_columns``)",
315
+ )
316
+ train_data_args.add_argument(
317
+ "--ignore-columns",
318
+ nargs="+",
319
+ help="Name of the columns to ignore when ``target_columns`` is not provided",
320
+ )
321
+ train_data_args.add_argument(
322
+ "--no-cache",
323
+ action="store_true",
324
+ help="Turn off caching the featurized ``MolGraph`` s at the beginning of training",
325
+ )
326
+ train_data_args.add_argument(
327
+ "--splits-column",
328
+ help="Name of the column in the input CSV file containing 'train', 'val', or 'test' for each row.",
329
+ )
330
+ # TODO: Add in v2.1
331
+ # train_data_args.add_argument(
332
+ # "--spectra-phase-mask-path",
333
+ # help="Path to a file containing a phase mask array, used for excluding particular regions in spectra predictions.",
334
+ # )
335
+
336
+ train_args = parser.add_argument_group("training args")
337
+ train_args.add_argument(
338
+ "-t",
339
+ "--task-type",
340
+ default="regression",
341
+ action=LookupAction(PredictorRegistry),
342
+ help="Type of dataset (determines the default loss function used during training, defaults to ``regression``)",
343
+ )
344
+ train_args.add_argument(
345
+ "-l",
346
+ "--loss-function",
347
+ action=LookupAction(LossFunctionRegistry),
348
+ help="Loss function to use during training (will use the default loss function for the given task type if not specified)",
349
+ )
350
+ train_args.add_argument(
351
+ "--v-kl",
352
+ "--evidential-regularization",
353
+ type=float,
354
+ default=0.0,
355
+ help="Specify the value used in regularization for evidential loss function. The default value recommended by Soleimany et al. (2021) is 0.2. However, the optimal value is dataset-dependent, so it is recommended that users test different values to find the best value for their model.",
356
+ )
357
+
358
+ train_args.add_argument(
359
+ "--eps", type=float, default=1e-8, help="Evidential regularization epsilon"
360
+ )
361
+ train_args.add_argument(
362
+ "--alpha", type=float, default=0.1, help="Target error bounds for quantile interval loss"
363
+ )
364
+ # TODO: Add in v2.1
365
+ # train_args.add_argument( # TODO: Is threshold the same thing as the spectra target floor? I'm not sure but combined them.
366
+ # "-T",
367
+ # "--threshold",
368
+ # "--spectra-target-floor",
369
+ # type=float,
370
+ # default=1e-8,
371
+ # help="spectral threshold limit. v1 help string: Values in targets for dataset type spectra are replaced with this value, intended to be a small positive number used to enforce positive values.",
372
+ # )
373
+ train_args.add_argument(
374
+ "--metrics",
375
+ "--metric",
376
+ nargs="+",
377
+ action=LookupAction(MetricRegistry),
378
+ help="Specify the evaluation metrics. If unspecified, chemprop will use the following metrics for given dataset types: regression -> ``rmse``, classification -> ``roc``, multiclass -> ``ce`` ('cross entropy'), spectral -> ``sid``. If multiple metrics are provided, the 0-th one will be used for early stopping and checkpointing.",
379
+ )
380
+ train_args.add_argument(
381
+ "--tracking-metric",
382
+ default="val_loss",
383
+ help="The metric to track for early stopping and checkpointing. Defaults to the criterion used during training.",
384
+ )
385
+ train_args.add_argument(
386
+ "--show-individual-scores",
387
+ action="store_true",
388
+ help="Show all scores for individual targets, not just average, at the end.",
389
+ )
390
+ train_args.add_argument(
391
+ "--task-weights",
392
+ nargs="+",
393
+ type=float,
394
+ help="Weights to apply for whole tasks in the loss function",
395
+ )
396
+ train_args.add_argument(
397
+ "--warmup-epochs",
398
+ type=int,
399
+ default=2,
400
+ help="Number of epochs during which learning rate increases linearly from ``init_lr`` to ``max_lr`` (afterwards, learning rate decreases exponentially from ``max_lr`` to ``final_lr``)",
401
+ )
402
+
403
+ train_args.add_argument("--init-lr", type=float, default=1e-4, help="Initial learning rate.")
404
+ train_args.add_argument("--max-lr", type=float, default=1e-3, help="Maximum learning rate.")
405
+ train_args.add_argument("--final-lr", type=float, default=1e-4, help="Final learning rate.")
406
+ train_args.add_argument("--epochs", type=int, default=50, help="Number of epochs to train over")
407
+ train_args.add_argument(
408
+ "--patience",
409
+ type=int,
410
+ default=None,
411
+ help="Number of epochs to wait for improvement before early stopping",
412
+ )
413
+ train_args.add_argument(
414
+ "--grad-clip",
415
+ type=float,
416
+ help="Passed directly to the lightning trainer which controls grad clipping (see the ``Trainer()`` docstring for details)",
417
+ )
418
+ train_args.add_argument(
419
+ "--class-balance",
420
+ action="store_true",
421
+ help="Ensures each training batch contains an equal number of positive and negative samples.",
422
+ )
423
+
424
+ split_args = parser.add_argument_group("split args")
425
+ split_args.add_argument(
426
+ "--split",
427
+ "--split-type",
428
+ type=uppercase,
429
+ default="RANDOM",
430
+ choices=list(SplitType.keys()),
431
+ help="Method of splitting the data into train/val/test (case insensitive)",
432
+ )
433
+ split_args.add_argument(
434
+ "--split-sizes",
435
+ type=float,
436
+ nargs=3,
437
+ default=[0.8, 0.1, 0.1],
438
+ help="Split proportions for train/validation/test sets",
439
+ )
440
+ split_args.add_argument(
441
+ "--split-key-molecule",
442
+ type=int,
443
+ default=0,
444
+ help="Specify the index of the key molecule used for splitting when multiple molecules are present and constrained split_type is used (e.g., ``scaffold_balanced`` or ``random_with_repeated_smiles``). Note that this index begins with zero for the first molecule.",
445
+ )
446
+ split_args.add_argument("--num-replicates", type=int, default=1, help="Number of replicates.")
447
+ split_args.add_argument("-k", "--num-folds", help=_CV_REMOVAL_ERROR)
448
+ split_args.add_argument(
449
+ "--save-smiles-splits",
450
+ action="store_true",
451
+ help="Whether to store the SMILES in each train/val/test split",
452
+ )
453
+ split_args.add_argument(
454
+ "--splits-file",
455
+ type=Path,
456
+ help="Path to a JSON file containing pre-defined splits for the input data, formatted as a list of dictionaries with keys ``train``, ``val``, and ``test`` and values as lists of indices or formatted strings (e.g. [0, 1, 2, 4] or '0-2,4')",
457
+ )
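+ # Illustrative `--splits-file` contents, following the format described above
+ # (indices may be given as lists or as range strings like '0-2,4'):
+ #
+ #     [
+ #         {"train": [0, 1, 2, 4], "val": "5-7", "test": [8, 9]},
+ #         {"train": "0-4", "val": [5, 6, 7], "test": "8,9"}
+ #     ]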
458
+ split_args.add_argument(
459
+ "--data-seed",
460
+ type=int,
461
+ default=0,
462
+ help="Specify the random seed to use when splitting data into train/val/test sets. When ``--num-replicates`` > 1, the first replicate uses this seed and all subsequent replicates add 1 to the seed (also used for shuffling data in ``build_dataloader`` when ``shuffle`` is True).",
463
+ )
464
+
465
+ parser.add_argument(
466
+ "--pytorch-seed",
467
+ type=int,
468
+ default=None,
469
+ help="Seed for PyTorch randomness (e.g., random initial weights)",
470
+ )
471
+
472
+ return parser
473
+
474
+
475
+ def process_train_args(args: Namespace) -> Namespace:
476
+ if args.output_dir is None:
477
+ args.output_dir = CHEMPROP_TRAIN_DIR / args.data_path.stem / NOW
478
+
479
+ return args
480
+
481
+
482
+ def validate_train_args(args):
483
+ if args.config_path is None and args.data_path is None:
484
+ raise ArgumentError(argument=None, message="Data path must be provided for training.")
485
+
486
+ if args.num_folds is not None: # i.e. user-specified
487
+ raise ArgumentError(argument=None, message=_CV_REMOVAL_ERROR)
488
+
489
+ if args.data_path.suffix not in [".csv"]:
490
+ raise ArgumentError(
491
+ argument=None, message=f"Input data must be a CSV file. Got {args.data_path}"
492
+ )
493
+
494
+ if args.epochs != -1 and args.epochs <= args.warmup_epochs:
495
+ raise ArgumentError(
496
+ argument=None,
497
+ message=f"The number of epochs should be higher than the number of epochs during warmup. Got {args.epochs} epochs and {args.warmup_epochs} warmup epochs",
498
+ )
499
+
500
+ # TODO: model_frzn is deprecated and will be removed in v2.2
501
+ if args.checkpoint is not None and args.model_frzn is not None:
502
+ raise ArgumentError(
503
+ argument=None,
504
+ message="`--checkpoint` and `--model-frzn` cannot be used at the same time.",
505
+ )
506
+
507
+ if "--model-frzn" in sys.argv:
508
+ logger.warning(
509
+ "`--model-frzn` is deprecated and will be removed in v2.2. "
510
+ "Please use `--checkpoint` with `--freeze-encoder` instead."
511
+ )
512
+
513
+ if args.freeze_encoder and args.checkpoint is None:
514
+ raise ArgumentError(
515
+ argument=None,
516
+ message="`--freeze-encoder` can only be used when `--checkpoint` is used.",
517
+ )
518
+
519
+ if args.frzn_ffn_layers > 0:
520
+ if args.checkpoint is None and args.model_frzn is None:
521
+ raise ArgumentError(
522
+ argument=None,
523
+ message="`--frzn-ffn-layers` can only be used when `--checkpoint` or `--model-frzn` (depreciated in v2.1) is used.",
524
+ )
525
+ if args.checkpoint is not None and not args.freeze_encoder:
526
+ raise ArgumentError(
527
+ argument=None,
528
+ message="To freeze the first `n` layers of the FFN via `--frzn-ffn-layers`. The message passing layer should also be frozen with `--freeze-encoder`.",
529
+ )
530
+
531
+ if args.class_balance and args.task_type != "classification":
532
+ raise ArgumentError(
533
+ argument=None, message="Class balance is only applicable for classification tasks."
534
+ )
535
+
536
+ valid_tracking_metrics = (
537
+ args.metrics or [PredictorRegistry[args.task_type]._T_default_metric.alias]
538
+ ) + ["val_loss"]
539
+ if args.tracking_metric not in valid_tracking_metrics:
540
+ raise ArgumentError(
541
+ argument=None,
542
+ message=f"Tracking metric must be one of {','.join(valid_tracking_metrics)}. "
543
+ f"Got {args.tracking_metric}. Additional tracking metric options can be specified with "
544
+ "the `--metrics` flag.",
545
+ )
546
+
547
+ input_cols, target_cols = get_column_names(
548
+ args.data_path,
549
+ args.smiles_columns,
550
+ args.reaction_columns,
551
+ args.target_columns,
552
+ args.ignore_columns,
553
+ args.splits_column,
554
+ args.weight_column,
555
+ args.no_header_row,
556
+ )
557
+
558
+ args.input_columns = input_cols
559
+ args.target_columns = target_cols
560
+
561
+ return args
562
+
563
+
564
+ def normalize_inputs(train_dset, val_dset, args):
565
+ multicomponent = isinstance(train_dset, MulticomponentDataset)
566
+ num_components = train_dset.n_components if multicomponent else 1
567
+
568
+ X_d_transform = None
569
+ V_f_transforms = [nn.Identity()] * num_components
570
+ E_f_transforms = [nn.Identity()] * num_components
571
+ V_d_transforms = [None] * num_components
572
+ graph_transforms = []
573
+
574
+ d_xd = train_dset.d_xd
575
+ d_vf = train_dset.d_vf
576
+ d_ef = train_dset.d_ef
577
+ d_vd = train_dset.d_vd
578
+
579
+ if d_xd > 0 and not args.no_descriptor_scaling:
580
+ scaler = train_dset.normalize_inputs("X_d")
581
+ val_dset.normalize_inputs("X_d", scaler)
582
+
583
+ scaler = scaler if not isinstance(scaler, list) else scaler[0]
584
+
585
+ if scaler is not None:
586
+ logger.info(
587
+ f"Descriptors: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}"
588
+ )
589
+ X_d_transform = ScaleTransform.from_standard_scaler(scaler)
590
+
591
+ if d_vf > 0 and not args.no_atom_feature_scaling:
592
+ scaler = train_dset.normalize_inputs("V_f")
593
+ val_dset.normalize_inputs("V_f", scaler)
594
+
595
+ scalers = [scaler] if not isinstance(scaler, list) else scaler
596
+
597
+ for i, scaler in enumerate(scalers):
598
+ if scaler is None:
599
+ continue
600
+
601
+ logger.info(
602
+ f"Atom features for mol {i}: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}"
603
+ )
604
+ featurizer = (
605
+ train_dset.datasets[i].featurizer if multicomponent else train_dset.featurizer
606
+ )
607
+ V_f_transforms[i] = ScaleTransform.from_standard_scaler(
608
+ scaler, pad=featurizer.atom_fdim - featurizer.extra_atom_fdim
609
+ )
610
+
611
+ if d_ef > 0 and not args.no_bond_feature_scaling:
612
+ scaler = train_dset.normalize_inputs("E_f")
613
+ val_dset.normalize_inputs("E_f", scaler)
614
+
615
+ scalers = [scaler] if not isinstance(scaler, list) else scaler
616
+
617
+ for i, scaler in enumerate(scalers):
618
+ if scaler is None:
619
+ continue
620
+
621
+ logger.info(
622
+ f"Bond features for mol {i}: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}"
623
+ )
624
+ featurizer = (
625
+ train_dset.datasets[i].featurizer if multicomponent else train_dset.featurizer
626
+ )
627
+ E_f_transforms[i] = ScaleTransform.from_standard_scaler(
628
+ scaler, pad=featurizer.bond_fdim - featurizer.extra_bond_fdim
629
+ )
630
+
631
+ for V_f_transform, E_f_transform in zip(V_f_transforms, E_f_transforms):
632
+ graph_transforms.append(GraphTransform(V_f_transform, E_f_transform))
633
+
634
+ if d_vd > 0 and not args.no_atom_descriptor_scaling:
635
+ scaler = train_dset.normalize_inputs("V_d")
636
+ val_dset.normalize_inputs("V_d", scaler)
637
+
638
+ scalers = [scaler] if not isinstance(scaler, list) else scaler
639
+
640
+ for i, scaler in enumerate(scalers):
641
+ if scaler is None:
642
+ continue
643
+
644
+ logger.info(
645
+ f"Atom descriptors for mol {i}: loc = {np.array2string(scaler.mean_, precision=3)}, scale = {np.array2string(scaler.scale_, precision=3)}"
646
+ )
647
+ V_d_transforms[i] = ScaleTransform.from_standard_scaler(scaler)
648
+
649
+ return X_d_transform, graph_transforms, V_d_transforms
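+ # The tuple returned above mirrors what `build_model` unpacks: `X_d_transform`
+ # scales extra datapoint descriptors, each `GraphTransform` pairs the per-component
+ # atom (V_f) and bond (E_f) feature scalers, and `V_d_transforms` scale extra atom
+ # descriptors. `nn.Identity()`/`None` entries mean no scaling was fit.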
650
+
651
+
652
+ def load_and_use_pretrained_model_scalers(model_path: Path, train_dset, val_dset) -> None:
653
+ if isinstance(train_dset, MulticomponentDataset):
654
+ _model = MulticomponentMPNN.load_from_file(model_path)
655
+ blocks = _model.message_passing.blocks
656
+ train_dsets = train_dset.datasets
657
+ val_dsets = val_dset.datasets
658
+ else:
659
+ _model = MPNN.load_from_file(model_path)
660
+ blocks = [_model.message_passing]
661
+ train_dsets = [train_dset]
662
+ val_dsets = [val_dset]
663
+
664
+ for i in range(len(blocks)):
665
+ if isinstance(_model.X_d_transform, ScaleTransform):
666
+ scaler = _model.X_d_transform.to_standard_scaler()
667
+ train_dsets[i].normalize_inputs("X_d", scaler)
668
+ val_dsets[i].normalize_inputs("X_d", scaler)
669
+
670
+ if isinstance(blocks[i].graph_transform, GraphTransform):
671
+ if isinstance(blocks[i].graph_transform.V_transform, ScaleTransform):
672
+ V_anti_pad = (
673
+ train_dsets[i].featurizer.atom_fdim - train_dsets[i].featurizer.extra_atom_fdim
674
+ )
675
+ scaler = blocks[i].graph_transform.V_transform.to_standard_scaler(
676
+ anti_pad=V_anti_pad
677
+ )
678
+ train_dsets[i].normalize_inputs("V_f", scaler)
679
+ val_dsets[i].normalize_inputs("V_f", scaler)
680
+ if isinstance(blocks[i].graph_transform.E_transform, ScaleTransform):
681
+ E_anti_pad = (
682
+ train_dsets[i].featurizer.bond_fdim - train_dsets[i].featurizer.extra_bond_fdim
683
+ )
684
+ scaler = blocks[i].graph_transform.E_transform.to_standard_scaler(
685
+ anti_pad=E_anti_pad
686
+ )
687
+ train_dsets[i].normalize_inputs("E_f", scaler)
688
+ val_dsets[i].normalize_inputs("E_f", scaler)
689
+
690
+ if isinstance(blocks[i].V_d_transform, ScaleTransform):
691
+ scaler = blocks[i].V_d_transform.to_standard_scaler()
692
+ train_dsets[i].normalize_inputs("V_d", scaler)
693
+ val_dsets[i].normalize_inputs("V_d", scaler)
694
+
695
+ if isinstance(_model.predictor.output_transform, UnscaleTransform):
696
+ scaler = _model.predictor.output_transform.to_standard_scaler()
697
+ train_dset.normalize_targets(scaler)
698
+ val_dset.normalize_targets(scaler)
699
+
700
+
701
+ def save_config(parser: ArgumentParser, args: Namespace, config_path: Path):
702
+ config_args = deepcopy(args)
703
+ for key, value in vars(config_args).items():
704
+ if isinstance(value, Path):
705
+ setattr(config_args, key, str(value))
706
+
707
+ for key in ["atom_features_path", "atom_descriptors_path", "bond_features_path"]:
708
+ if getattr(config_args, key) is not None:
709
+ for index, path in getattr(config_args, key).items():
710
+ getattr(config_args, key)[index] = str(path)
711
+
712
+ parser.write_config_file(parsed_namespace=config_args, output_file_paths=[str(config_path)])
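+ # NOTE: `write_config_file` is not part of stdlib argparse; this assumes the
+ # parser is a configargparse-style `ArgumentParser` that supports config files.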
713
+
714
+
715
+ def save_smiles_splits(args: Namespace, output_dir, train_dset, val_dset, test_dset):
716
+ match (args.smiles_columns, args.reaction_columns):
717
+ case [_, None]:
718
+ column_labels = deepcopy(args.smiles_columns)
719
+ case [None, _]:
720
+ column_labels = deepcopy(args.reaction_columns)
721
+ case _:
722
+ column_labels = deepcopy(args.smiles_columns)
723
+ column_labels.extend(args.reaction_columns)
724
+
725
+ train_smis = train_dset.names
726
+ df_train = pd.DataFrame(train_smis, columns=column_labels)
727
+ df_train.to_csv(output_dir / "train_smiles.csv", index=False)
728
+
729
+ val_smis = val_dset.names
730
+ df_val = pd.DataFrame(val_smis, columns=column_labels)
731
+ df_val.to_csv(output_dir / "val_smiles.csv", index=False)
732
+
733
+ if test_dset is not None:
734
+ test_smis = test_dset.names
735
+ df_test = pd.DataFrame(test_smis, columns=column_labels)
736
+ df_test.to_csv(output_dir / "test_smiles.csv", index=False)
737
+
738
+
739
+ def build_splits(args, format_kwargs, featurization_kwargs):
740
+ """build the train/val/test splits"""
741
+ logger.info(f"Pulling data from file: {args.data_path}")
742
+ all_data = build_data_from_files(
743
+ args.data_path,
744
+ p_descriptors=args.descriptors_path,
745
+ p_atom_feats=args.atom_features_path,
746
+ p_bond_feats=args.bond_features_path,
747
+ p_atom_descs=args.atom_descriptors_path,
748
+ **format_kwargs,
749
+ **featurization_kwargs,
750
+ )
751
+
752
+ if args.splits_column is not None:
753
+ df = pd.read_csv(
754
+ args.data_path, header=None if args.no_header_row else "infer", index_col=False
755
+ )
756
+ grouped = df.groupby(df[args.splits_column].str.lower())
757
+ train_indices = grouped.groups.get("train", pd.Index([])).tolist()
758
+ val_indices = grouped.groups.get("val", pd.Index([])).tolist()
759
+ test_indices = grouped.groups.get("test", pd.Index([])).tolist()
760
+ train_indices, val_indices, test_indices = [train_indices], [val_indices], [test_indices]
761
+
762
+ elif args.splits_file is not None:
763
+ with open(args.splits_file, "rb") as json_file:
764
+ split_idxss = json.load(json_file)
765
+ train_indices = [parse_indices(d["train"]) for d in split_idxss]
766
+ val_indices = [parse_indices(d["val"]) for d in split_idxss]
767
+ test_indices = [parse_indices(d["test"]) for d in split_idxss]
768
+ args.num_replicates = len(split_idxss)
769
+
770
+ else:
771
+ splitting_data = all_data[args.split_key_molecule]
772
+ if isinstance(splitting_data[0], ReactionDatapoint):
773
+ splitting_mols = [datapoint.rct for datapoint in splitting_data]
774
+ else:
775
+ splitting_mols = [datapoint.mol for datapoint in splitting_data]
776
+ train_indices, val_indices, test_indices = make_split_indices(
777
+ splitting_mols, args.split, args.split_sizes, args.data_seed, args.num_replicates
778
+ )
779
+
780
+ train_data, val_data, test_data = split_data_by_indices(
781
+ all_data, train_indices, val_indices, test_indices
782
+ )
783
+ for i_split in range(len(train_data)):
784
+ sizes = [len(train_data[i_split][0]), len(val_data[i_split][0]), len(test_data[i_split][0])]
785
+ logger.info(f"train/val/test split_{i_split} sizes: {sizes}")
786
+
787
+ return train_data, val_data, test_data
788
+
789
+
790
+ def summarize(
791
+ target_cols: list[str], task_type: str, dataset: _MolGraphDatasetMixin
792
+ ) -> tuple[list, list]:
793
+ if task_type in [
794
+ "regression",
795
+ "regression-mve",
796
+ "regression-evidential",
797
+ "regression-quantile",
798
+ ]:
799
+ if isinstance(dataset, MulticomponentDataset):
800
+ y = dataset.datasets[0].Y
801
+ else:
802
+ y = dataset.Y
803
+ y_mean = np.nanmean(y, axis=0)
804
+ y_std = np.nanstd(y, axis=0)
805
+ y_median = np.nanmedian(y, axis=0)
806
+ mean_dev_abs = np.abs(y - y_mean)
807
+ num_targets = np.sum(~np.isnan(y), axis=0)
808
+ frac_1_sigma = np.sum((mean_dev_abs < y_std), axis=0) / num_targets
809
+ frac_2_sigma = np.sum((mean_dev_abs < 2 * y_std), axis=0) / num_targets
810
+
811
+ column_headers = ["Statistic"] + [f"Value ({target_cols[i]})" for i in range(y.shape[1])]
812
+ table_rows = [
813
+ ["Num. smiles"] + [f"{len(y)}" for i in range(y.shape[1])],
814
+ ["Num. targets"] + [f"{num_targets[i]}" for i in range(y.shape[1])],
815
+ ["Num. NaN"] + [f"{len(y) - num_targets[i]}" for i in range(y.shape[1])],
816
+ ["Mean"] + [f"{mean:0.3g}" for mean in y_mean],
817
+ ["Std. dev."] + [f"{std:0.3g}" for std in y_std],
818
+ ["Median"] + [f"{median:0.3g}" for median in y_median],
819
+ ["% within 1 s.d."] + [f"{sigma:0.0%}" for sigma in frac_1_sigma],
820
+ ["% within 2 s.d."] + [f"{sigma:0.0%}" for sigma in frac_2_sigma],
821
+ ]
822
+ return (column_headers, table_rows)
823
+ elif task_type in [
824
+ "classification",
825
+ "classification-dirichlet",
826
+ "multiclass",
827
+ "multiclass-dirichlet",
828
+ ]:
829
+ if isinstance(dataset, MulticomponentDataset):
830
+ y = dataset.datasets[0].Y
831
+ else:
832
+ y = dataset.Y
833
+
834
+ mask = np.isnan(y)
835
+ classes = np.sort(np.unique(y[~mask]))
836
+
837
+ class_counts = np.stack([(classes[:, None] == y[:, i]).sum(1) for i in range(y.shape[1])])
838
+ class_fracs = class_counts / y.shape[0]
839
+ nan_count = np.nansum(mask, axis=0)
840
+ nan_frac = nan_count / y.shape[0]
841
+
842
+ column_headers = ["Class"] + [f"Count/Percent {target_cols[i]}" for i in range(y.shape[1])]
843
+
844
+ table_rows = [
845
+ [f"{k}"] + [f"{class_counts[j, i]}/{class_fracs[j, i]:0.0%}" for j in range(y.shape[1])]
846
+ for i, k in enumerate(classes)
847
+ ]
848
+
849
+ nan_row = ["NaN"] + [f"{nan_count[i]}/{nan_frac[i]:0.0%}" for i in range(y.shape[1])]
850
+ table_rows.append(nan_row)
851
+
852
+ total_row = ["Total"] + [f"{y.shape[0]}/{100.00}%" for i in range(y.shape[1])]
853
+ table_rows.append(total_row)
854
+
855
+ return (column_headers, table_rows)
856
+ else:
857
+ raise ValueError(f"unsupported task type! Task type '{task_type}' was not recognized.")
858
+
859
+
860
+ def build_table(column_headers: list[str], table_rows: list[str], title: str | None = None) -> str:
861
+ right_justified_columns = [
862
+ Column(header=column_header, justify="right") for column_header in column_headers
863
+ ]
864
+ table = Table(*right_justified_columns, title=title)
865
+ for row in table_rows:
866
+ table.add_row(*row)
867
+
868
+ console = Console(record=True, file=StringIO(), width=200)
869
+ console.print(table)
870
+ return console.export_text()
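+ # Illustrative usage:
+ #     build_table(["Statistic", "Value (y)"], [["Mean", "1.23"]], "Summary of Training Data")
+ # renders a right-justified table via rich and returns it as plain text for logging.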
871
+
872
+
873
+ def build_datasets(args, train_data, val_data, test_data):
874
+ """build the train/val/test datasets, where :attr:`test_data` may be None"""
875
+ multicomponent = len(train_data) > 1
876
+ if multicomponent:
877
+ train_dsets = [
878
+ make_dataset(data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
879
+ for data in train_data
880
+ ]
881
+ val_dsets = [
882
+ make_dataset(data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
883
+ for data in val_data
884
+ ]
885
+ train_dset = MulticomponentDataset(train_dsets)
886
+ val_dset = MulticomponentDataset(val_dsets)
887
+ if len(test_data[0]) > 0:
888
+ test_dsets = [
889
+ make_dataset(data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
890
+ for data in test_data
891
+ ]
892
+ test_dset = MulticomponentDataset(test_dsets)
893
+ else:
894
+ test_dset = None
895
+ else:
896
+ train_data = train_data[0]
897
+ val_data = val_data[0]
898
+ test_data = test_data[0]
899
+ train_dset = make_dataset(train_data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
900
+ val_dset = make_dataset(val_data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
901
+ if len(test_data) > 0:
902
+ test_dset = make_dataset(test_data, args.rxn_mode, args.multi_hot_atom_featurizer_mode)
903
+ else:
904
+ test_dset = None
905
+ if args.task_type != "spectral":
906
+ for dataset, label in zip(
907
+ [train_dset, val_dset, test_dset], ["Training", "Validation", "Test"]
908
+ ):
909
+ column_headers, table_rows = summarize(args.target_columns, args.task_type, dataset)
910
+ output = build_table(column_headers, table_rows, f"Summary of {label} Data")
911
+ logger.info("\n" + output)
912
+
913
+ return train_dset, val_dset, test_dset
914
+
915
+
916
+ def build_model(
917
+ args,
918
+ train_dset: MolGraphDataset | MulticomponentDataset,
919
+ output_transform: UnscaleTransform,
920
+ input_transforms: tuple[ScaleTransform, list[GraphTransform], list[ScaleTransform]],
921
+ ) -> MPNN:
922
+ mp_cls = AtomMessagePassing if args.atom_messages else BondMessagePassing
923
+
924
+ X_d_transform, graph_transforms, V_d_transforms = input_transforms
925
+ if isinstance(train_dset, MulticomponentDataset):
926
+ mp_blocks = [
927
+ mp_cls(
928
+ train_dset.datasets[i].featurizer.atom_fdim,
929
+ train_dset.datasets[i].featurizer.bond_fdim,
930
+ d_h=args.message_hidden_dim,
931
+ d_vd=(
932
+ train_dset.datasets[i].d_vd
933
+ if isinstance(train_dset.datasets[i], MoleculeDataset)
934
+ else 0
935
+ ),
936
+ bias=args.message_bias,
937
+ depth=args.depth,
938
+ undirected=args.undirected,
939
+ dropout=args.dropout,
940
+ activation=args.activation,
941
+ V_d_transform=V_d_transforms[i],
942
+ graph_transform=graph_transforms[i],
943
+ )
944
+ for i in range(train_dset.n_components)
945
+ ]
946
+ if args.mpn_shared:
947
+ if args.reaction_columns is not None and args.smiles_columns is not None:
948
+ raise ArgumentError(
949
+ argument=None,
950
+ message="Cannot use shared MPNN with both molecule and reaction data.",
951
+ )
952
+
953
+ mp_block = MulticomponentMessagePassing(mp_blocks, train_dset.n_components, args.mpn_shared)
954
+ # NOTE(degraff): this if/else block should be handled by the init of MulticomponentMessagePassing
955
+ # if args.mpn_shared:
956
+ # mp_block = MulticomponentMessagePassing(mp_blocks[0], n_components, args.mpn_shared)
957
+ # else:
958
+ d_xd = train_dset.datasets[0].d_xd
959
+ n_tasks = train_dset.datasets[0].Y.shape[1]
960
+ mpnn_cls = MulticomponentMPNN
961
+ else:
962
+ mp_block = mp_cls(
963
+ train_dset.featurizer.atom_fdim,
964
+ train_dset.featurizer.bond_fdim,
965
+ d_h=args.message_hidden_dim,
966
+ d_vd=train_dset.d_vd if isinstance(train_dset, MoleculeDataset) else 0,
967
+ bias=args.message_bias,
968
+ depth=args.depth,
969
+ undirected=args.undirected,
970
+ dropout=args.dropout,
971
+ activation=args.activation,
972
+ V_d_transform=V_d_transforms[0],
973
+ graph_transform=graph_transforms[0],
974
+ )
975
+ d_xd = train_dset.d_xd
976
+ n_tasks = train_dset.Y.shape[1]
977
+ mpnn_cls = MPNN
978
+
979
+ agg = Factory.build(AggregationRegistry[args.aggregation], norm=args.aggregation_norm)
980
+ predictor_cls = PredictorRegistry[args.task_type]
981
+ if args.loss_function is not None:
982
+ task_weights = torch.ones(n_tasks) if args.task_weights is None else args.task_weights
983
+ criterion = Factory.build(
984
+ LossFunctionRegistry[args.loss_function],
985
+ task_weights=task_weights,
986
+ v_kl=args.v_kl,
987
+ # threshold=args.threshold, TODO: Add in v2.1
988
+ eps=args.eps,
989
+ alpha=args.alpha,
990
+ )
991
+ else:
992
+ criterion = None
993
+ if args.metrics is not None:
994
+ metrics = [Factory.build(MetricRegistry[metric]) for metric in args.metrics]
995
+ else:
996
+ metrics = None
997
+
998
+ predictor = Factory.build(
999
+ predictor_cls,
1000
+ input_dim=mp_block.output_dim + d_xd,
1001
+ n_tasks=n_tasks,
1002
+ hidden_dim=args.ffn_hidden_dim,
1003
+ n_layers=args.ffn_num_layers,
1004
+ dropout=args.dropout,
1005
+ activation=args.activation,
1006
+ criterion=criterion,
1007
+ task_weights=args.task_weights,
1008
+ n_classes=args.multiclass_num_classes,
1009
+ output_transform=output_transform,
1010
+ # spectral_activation=args.spectral_activation, TODO: Add in v2.1
1011
+ )
1012
+
1013
+ if args.loss_function is None:
1014
+ logger.info(
1015
+ f"No loss function was specified! Using class default: {predictor_cls._T_default_criterion}"
1016
+ )
1017
+
1018
+ return mpnn_cls(
1019
+ mp_block,
1020
+ agg,
1021
+ predictor,
1022
+ args.batch_norm,
1023
+ metrics,
1024
+ args.warmup_epochs,
1025
+ args.init_lr,
1026
+ args.max_lr,
1027
+ args.final_lr,
1028
+ X_d_transform=X_d_transform,
1029
+ )
1030
+
1031
+
1032
+ def train_model(
1033
+ args, train_loader, val_loader, test_loader, output_dir, output_transform, input_transforms
1034
+ ):
1035
+ if args.checkpoint is not None:
1036
+ model_paths = find_models(args.checkpoint)
1037
+ if args.ensemble_size != len(model_paths):
1038
+ logger.warning(
1039
+ f"The number of models in ensemble for each splitting of data is set to {len(model_paths)}."
1040
+ )
1041
+ args.ensemble_size = len(model_paths)
1042
+
1043
+ for model_idx in range(args.ensemble_size):
1044
+ model_output_dir = output_dir / f"model_{model_idx}"
1045
+ model_output_dir.mkdir(exist_ok=True, parents=True)
1046
+
1047
+ if args.pytorch_seed is None:
1048
+ seed = torch.seed()
1049
+ deterministic = False
1050
+ else:
1051
+ seed = args.pytorch_seed + model_idx
1052
+ deterministic = True
1053
+
1054
+ torch.manual_seed(seed)
1055
+
1056
+ if args.checkpoint or args.model_frzn is not None:
1057
+ mpnn_cls = (
1058
+ MulticomponentMPNN
1059
+ if isinstance(train_loader.dataset, MulticomponentDataset)
1060
+ else MPNN
1061
+ )
1062
+ model_path = model_paths[model_idx] if args.checkpoint else args.model_frzn
1063
+ model = mpnn_cls.load_from_file(model_path)
1064
+
1065
+ if args.checkpoint:
1066
+ model.apply(
1067
+ lambda m: setattr(m, "p", args.dropout)
1068
+ if isinstance(m, torch.nn.Dropout)
1069
+ else None
1070
+ )
1071
+
1072
+ # TODO: model_frzn is deprecated and will be removed in v2.2
1073
+ if args.model_frzn or args.freeze_encoder:
1074
+ model.message_passing.apply(lambda module: module.requires_grad_(False))
1075
+ model.message_passing.eval()
1076
+ model.bn.apply(lambda module: module.requires_grad_(False))
1077
+ model.bn.eval()
1078
+ for idx in range(args.frzn_ffn_layers):
1079
+ model.predictor.ffn[idx].requires_grad_(False)
1080
+ model.predictor.ffn[idx + 1].eval()
1081
+ else:
1082
+ model = build_model(args, train_loader.dataset, output_transform, input_transforms)
1083
+ logger.info(model)
1084
+
1085
+ try:
1086
+ trainer_logger = TensorBoardLogger(
1087
+ model_output_dir, "trainer_logs", default_hp_metric=False
1088
+ )
1089
+ except ModuleNotFoundError as e:
1090
+ logger.warning(
1091
+ f"Unable to import TensorBoardLogger, reverting to CSVLogger (original error: {e})."
1092
+ )
1093
+ trainer_logger = CSVLogger(model_output_dir, "trainer_logs")
1094
+
1095
+ if args.tracking_metric == "val_loss":
1096
+ T_tracking_metric = model.criterion.__class__
1097
+ tracking_metric = args.tracking_metric
1098
+ else:
1099
+ T_tracking_metric = MetricRegistry[args.tracking_metric]
1100
+ tracking_metric = "val/" + args.tracking_metric
1101
+
1102
+ monitor_mode = "max" if T_tracking_metric.higher_is_better else "min"
1103
+ logger.debug(f"Evaluation metric: '{T_tracking_metric.alias}', mode: '{monitor_mode}'")
1104
+
1105
+ if args.remove_checkpoints:
1106
+ temp_dir = TemporaryDirectory()
1107
+ checkpoint_dir = Path(temp_dir.name)
1108
+ else:
1109
+ checkpoint_dir = model_output_dir
1110
+
1111
+ checkpoint_filename = (
1112
+ f"best-epoch={{epoch}}-{tracking_metric.replace('/', '_')}="
1113
+ f"{{{tracking_metric}:.2f}}"
1114
+ )
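+ # With Lightning's filename templating this yields names like
+ # "best-epoch=12-val_loss=0.45.ckpt" (illustrative values).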
1115
+ checkpointing = ModelCheckpoint(
1116
+ checkpoint_dir / "checkpoints",
1117
+ checkpoint_filename,
1118
+ tracking_metric,
1119
+ mode=monitor_mode,
1120
+ save_last=True,
1121
+ auto_insert_metric_name=False,
1122
+ )
1123
+
1124
+ if args.epochs != -1:
1125
+ patience = args.patience if args.patience is not None else args.epochs
1126
+ early_stopping = EarlyStopping(tracking_metric, patience=patience, mode=monitor_mode)
1127
+ callbacks = [checkpointing, early_stopping]
1128
+ else:
1129
+ callbacks = [checkpointing]
1130
+
1131
+ trainer = pl.Trainer(
1132
+ logger=trainer_logger,
1133
+ enable_progress_bar=True,
1134
+ accelerator=args.accelerator,
1135
+ devices=args.devices,
1136
+ max_epochs=args.epochs,
1137
+ callbacks=callbacks,
1138
+ gradient_clip_val=args.grad_clip,
1139
+ deterministic=deterministic,
1140
+ )
1141
+ trainer.fit(model, train_loader, val_loader)
1142
+
1143
+ if test_loader is not None:
1144
+ if isinstance(trainer.strategy, DDPStrategy):
1145
+ torch.distributed.destroy_process_group()
1146
+
1147
+ best_ckpt_path = trainer.checkpoint_callback.best_model_path
1148
+ trainer = pl.Trainer(
1149
+ logger=trainer_logger,
1150
+ enable_progress_bar=True,
1151
+ accelerator=args.accelerator,
1152
+ devices=1,
1153
+ )
1154
+ model = model.load_from_checkpoint(best_ckpt_path)
1155
+ predss = trainer.predict(model, dataloaders=test_loader)
1156
+ else:
1157
+ predss = trainer.predict(dataloaders=test_loader)
1158
+
1159
+ preds = torch.concat(predss, 0)
1160
+ if model.predictor.n_targets > 1:
1161
+ preds = preds[..., 0]
1162
+ preds = preds.numpy()
1163
+
1164
+ evaluate_and_save_predictions(
1165
+ preds, test_loader, model.metrics[:-1], model_output_dir, args
1166
+ )
1167
+
1168
+ best_model_path = checkpointing.best_model_path
1169
+ model = model.__class__.load_from_checkpoint(best_model_path)
1170
+ p_model = model_output_dir / "best.pt"
1171
+ save_model(p_model, model, args.target_columns)
1172
+ logger.info(f"Best model saved to '{p_model}'")
1173
+
1174
+ if args.remove_checkpoints:
1175
+ temp_dir.cleanup()
1176
+
1177
+
1178
+ def evaluate_and_save_predictions(preds, test_loader, metrics, model_output_dir, args):
1179
+ if isinstance(test_loader.dataset, MulticomponentDataset):
1180
+ test_dset = test_loader.dataset.datasets[0]
1181
+ else:
1182
+ test_dset = test_loader.dataset
1183
+ targets = test_dset.Y
1184
+ mask = torch.from_numpy(np.isfinite(targets))
1185
+ targets = np.nan_to_num(targets, nan=0.0)
1186
+ weights = torch.ones(len(test_dset))
1187
+ lt_mask = torch.from_numpy(test_dset.lt_mask) if test_dset.lt_mask[0] is not None else None
1188
+ gt_mask = torch.from_numpy(test_dset.gt_mask) if test_dset.gt_mask[0] is not None else None
1189
+
1190
+ individual_scores = dict()
1191
+ for metric in metrics:
1192
+ individual_scores[metric.alias] = []
1193
+ for i, col in enumerate(args.target_columns):
1194
+ if "multiclass" in args.task_type:
1195
+ preds_slice = torch.from_numpy(preds[:, i : i + 1, :])
1196
+ targets_slice = torch.from_numpy(targets[:, i : i + 1])
1197
+ else:
1198
+ preds_slice = torch.from_numpy(preds[:, i : i + 1])
1199
+ targets_slice = torch.from_numpy(targets[:, i : i + 1])
1200
+ preds_loss = metric(
1201
+ preds_slice,
1202
+ targets_slice,
1203
+ mask[:, i : i + 1],
1204
+ weights,
1205
+ lt_mask[:, i] if lt_mask is not None else None,
1206
+ gt_mask[:, i] if gt_mask is not None else None,
1207
+ )
1208
+ individual_scores[metric.alias].append(preds_loss)
1209
+
1210
+ logger.info("Test Set results:")
1211
+ for metric in metrics:
1212
+ avg_loss = sum(individual_scores[metric.alias]) / len(individual_scores[metric.alias])
1213
+ logger.info(f"test/{metric.alias}: {avg_loss}")
1214
+
1215
+ if args.show_individual_scores:
1216
+ logger.info("Entire Test Set individual results:")
1217
+ for metric in metrics:
1218
+ for i, col in enumerate(args.target_columns):
1219
+ logger.info(f"test/{col}/{metric.alias}: {individual_scores[metric.alias][i]}")
1220
+
1221
+ names = test_loader.dataset.names
1222
+ if isinstance(test_loader.dataset, MulticomponentDataset):
1223
+ namess = list(zip(*names))
1224
+ else:
1225
+ namess = [names]
1226
+
1227
+ columns = args.input_columns + args.target_columns
1228
+ if "multiclass" in args.task_type:
1229
+ columns = columns + [f"{col}_prob" for col in args.target_columns]
1230
+ formatted_probability_strings = np.apply_along_axis(
1231
+ lambda x: ",".join(map(str, x)), 2, preds
1232
+ )
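+ # e.g. per-class probabilities [0.1, 0.7, 0.2] become the string "0.1,0.7,0.2"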
1233
+ predicted_class_labels = preds.argmax(axis=-1)
1234
+ df_preds = pd.DataFrame(
1235
+ list(zip(*namess, *predicted_class_labels.T, *formatted_probability_strings.T)),
1236
+ columns=columns,
1237
+ )
1238
+ else:
1239
+ df_preds = pd.DataFrame(list(zip(*namess, *preds.T)), columns=columns)
1240
+ df_preds.to_csv(model_output_dir / "test_predictions.csv", index=False)
1241
+
1242
+
1243
+ def main(args):
1244
+ format_kwargs = dict(
1245
+ no_header_row=args.no_header_row,
1246
+ smiles_cols=args.smiles_columns,
1247
+ rxn_cols=args.reaction_columns,
1248
+ target_cols=args.target_columns,
1249
+ ignore_cols=args.ignore_columns,
1250
+ splits_col=args.splits_column,
1251
+ weight_col=args.weight_column,
1252
+ bounded=args.loss_function is not None and "bounded" in args.loss_function,
1253
+ )
1254
+
1255
+ featurization_kwargs = dict(
1256
+ molecule_featurizers=args.molecule_featurizers, keep_h=args.keep_h, add_h=args.add_h
1257
+ )
1258
+
1259
+ splits = build_splits(args, format_kwargs, featurization_kwargs)
1260
+
1261
+ for replicate_idx, (train_data, val_data, test_data) in enumerate(zip(*splits)):
1262
+ if args.num_replicates == 1:
1263
+ output_dir = args.output_dir
1264
+ else:
1265
+ output_dir = args.output_dir / f"replicate_{replicate_idx}"
1266
+
1267
+ output_dir.mkdir(exist_ok=True, parents=True)
1268
+
1269
+ train_dset, val_dset, test_dset = build_datasets(args, train_data, val_data, test_data)
1270
+
1271
+ if args.save_smiles_splits:
1272
+ save_smiles_splits(args, output_dir, train_dset, val_dset, test_dset)
1273
+
1274
+ if args.checkpoint or args.model_frzn is not None:
1275
+ model_paths = find_models(args.checkpoint)
1276
+ if len(model_paths) > 1:
1277
+ logger.warning(
1278
+ "Multiple checkpoint files were loaded, but only the scalers from "
1279
+ f"{model_paths[0]} are used. It is assumed that all models provided have the "
1280
+ "same data scalings, meaning they were trained on the same data."
1281
+ )
1282
+ model_path = model_paths[0] if args.checkpoint else args.model_frzn
1283
+ load_and_use_pretrained_model_scalers(model_path, train_dset, val_dset)
1284
+ input_transforms = (None, None, None)
1285
+ output_transform = None
1286
+ else:
1287
+ input_transforms = normalize_inputs(train_dset, val_dset, args)
1288
+
1289
+ if "regression" in args.task_type:
1290
+ output_scaler = train_dset.normalize_targets()
1291
+ val_dset.normalize_targets(output_scaler)
1292
+ logger.info(
1293
+ f"Train data: mean = {output_scaler.mean_} | std = {output_scaler.scale_}"
1294
+ )
1295
+ output_transform = UnscaleTransform.from_standard_scaler(output_scaler)
1296
+ else:
1297
+ output_transform = None
1298
+
1299
+ if not args.no_cache:
1300
+ train_dset.cache = True
1301
+ val_dset.cache = True
1302
+
1303
+ train_loader = build_dataloader(
1304
+ train_dset,
1305
+ args.batch_size,
1306
+ args.num_workers,
1307
+ class_balance=args.class_balance,
1308
+ seed=args.data_seed,
1309
+ )
1310
+ if args.class_balance:
1311
+ logger.debug(
1312
+ f"With `--class-balance`, effective train size = {len(train_loader.sampler)}"
1313
+ )
1314
+ val_loader = build_dataloader(val_dset, args.batch_size, args.num_workers, shuffle=False)
1315
+ if test_dset is not None:
1316
+ test_loader = build_dataloader(
1317
+ test_dset, args.batch_size, args.num_workers, shuffle=False
1318
+ )
1319
+ else:
1320
+ test_loader = None
1321
+
1322
+ train_model(
1323
+ args,
1324
+ train_loader,
1325
+ val_loader,
1326
+ test_loader,
1327
+ output_dir,
1328
+ output_transform,
1329
+ input_transforms,
1330
+ )
1331
+
1332
+
1333
+ if __name__ == "__main__":
1334
+ # TODO: update this old code or remove it.
1335
+ parser = ArgumentParser()
1336
+ parser = TrainSubcommand.add_args(parser)
1337
+
1338
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG, force=True)
1339
+ args = parser.parse_args()
1340
+ TrainSubcommand.func(args)
chemprop-updated/chemprop/cli/utils/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ from .actions import LookupAction
2
+ from .args import bounded
3
+ from .command import Subcommand
4
+ from .parsing import (
5
+ build_data_from_files,
6
+ get_column_names,
7
+ make_datapoints,
8
+ make_dataset,
9
+ parse_indices,
10
+ )
11
+ from .utils import _pop_attr, _pop_attr_d, pop_attr
12
+
13
+ __all__ = [
14
+ "bounded",
15
+ "LookupAction",
16
+ "Subcommand",
17
+ "build_data_from_files",
18
+ "make_datapoints",
19
+ "make_dataset",
20
+ "get_column_names",
21
+ "parse_indices",
22
+ "actions",
23
+ "args",
24
+ "command",
25
+ "parsing",
26
+ "utils",
27
+ "pop_attr",
28
+ "_pop_attr",
29
+ "_pop_attr_d",
30
+ ]
chemprop-updated/chemprop/cli/utils/actions.py ADDED
@@ -0,0 +1,19 @@
1
+ from argparse import _StoreAction
2
+ from typing import Any, Mapping
3
+
4
+
5
+ def LookupAction(obj: Mapping[str, Any]):
6
+ class LookupAction_(_StoreAction):
7
+ def __init__(self, option_strings, dest, default=None, choices=None, **kwargs):
8
+ if default not in obj.keys() and default is not None:
9
+ raise ValueError(
10
+ f"Invalid value for arg 'default': '{default}'. "
11
+ f"Expected one of {tuple(obj.keys())}"
12
+ )
13
+
14
+ kwargs["choices"] = choices if choices is not None else obj.keys()
15
+ kwargs["default"] = default
16
+
17
+ super().__init__(option_strings, dest, **kwargs)
18
+
19
+ return LookupAction_
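+ # Example usage (as in train.py within this diff):
+ #     parser.add_argument("--aggregation", action=LookupAction(AggregationRegistry))
+ # restricts `choices` to the registry's keys and validates any `default` against them.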
chemprop-updated/chemprop/cli/utils/args.py ADDED
@@ -0,0 +1,34 @@
1
+ import functools
2
+
3
+ __all__ = ["bounded"]
4
+
5
+
6
+ def bounded(lo: float | None = None, hi: float | None = None):
7
+ if lo is None and hi is None:
8
+ raise ValueError("No bounds provided!")
9
+
10
+ def decorator(f):
11
+ @functools.wraps(f)
12
+ def wrapper(*args, **kwargs):
13
+ x = f(*args, **kwargs)
14
+
15
+ if (lo is not None and hi is not None) and not lo <= x <= hi:
16
+ raise ValueError(f"Parsed value outside of range [{lo}, {hi}]! got: {x}")
17
+ if hi is not None and x > hi:
18
+ raise ValueError(f"Parsed value below {hi}! got: {x}")
19
+ if lo is not None and x < lo:
20
+ raise ValueError(f"Parsed value above {lo}]! got: {x}")
21
+
22
+ return x
23
+
24
+ return wrapper
25
+
26
+ return decorator
27
+
28
+
29
+ def uppercase(x: str):
30
+ return x.upper()
31
+
32
+
33
+ def lowercase(x: str):
34
+ return x.lower()
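+ # Illustrative usage of `bounded`: wrap a parsing callable so out-of-range CLI
+ # values raise a ValueError instead of passing through silently:
+ #
+ #     parser.add_argument("--dropout", type=bounded(lo=0.0, hi=1.0)(float))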
chemprop-updated/chemprop/cli/utils/command.py ADDED
@@ -0,0 +1,24 @@
1
+ from abc import ABC, abstractmethod
2
+ from argparse import ArgumentParser, Namespace, _SubParsersAction
3
+
4
+
5
+ class Subcommand(ABC):
6
+ COMMAND: str
7
+ HELP: str | None = None
8
+
9
+ @classmethod
10
+ def add(cls, subparsers: _SubParsersAction, parents) -> ArgumentParser:
11
+ parser = subparsers.add_parser(cls.COMMAND, help=cls.HELP, parents=parents)
12
+ cls.add_args(parser).set_defaults(func=cls.func)
13
+
14
+ return parser
15
+
16
+ @classmethod
17
+ @abstractmethod
18
+ def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
19
+ pass
20
+
21
+ @classmethod
22
+ @abstractmethod
23
+ def func(cls, args: Namespace):
24
+ pass
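+ # Minimal sketch of a concrete subcommand (hypothetical example, not part of the CLI):
+ #
+ #     class EchoSubcommand(Subcommand):
+ #         COMMAND = "echo"
+ #         HELP = "Print a message."
+ #
+ #         @classmethod
+ #         def add_args(cls, parser: ArgumentParser) -> ArgumentParser:
+ #             parser.add_argument("--message", default="hello")
+ #             return parser
+ #
+ #         @classmethod
+ #         def func(cls, args: Namespace):
+ #             print(args.message)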
chemprop-updated/chemprop/cli/utils/parsing.py ADDED
@@ -0,0 +1,446 @@
1
+ import logging
2
+ from os import PathLike
3
+ from typing import Literal, Mapping, Sequence
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+
8
+ from chemprop.data.datapoints import MoleculeDatapoint, ReactionDatapoint
9
+ from chemprop.data.datasets import MoleculeDataset, ReactionDataset
10
+ from chemprop.featurizers.atom import get_multi_hot_atom_featurizer
11
+ from chemprop.featurizers.bond import MultiHotBondFeaturizer, RIGRBondFeaturizer
12
+ from chemprop.featurizers.molecule import MoleculeFeaturizerRegistry
13
+ from chemprop.featurizers.molgraph import (
14
+ CondensedGraphOfReactionFeaturizer,
15
+ SimpleMoleculeMolGraphFeaturizer,
16
+ )
17
+ from chemprop.utils import make_mol
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
+ def parse_csv(
23
+ path: PathLike,
24
+ smiles_cols: Sequence[str] | None,
25
+ rxn_cols: Sequence[str] | None,
26
+ target_cols: Sequence[str] | None,
27
+ ignore_cols: Sequence[str] | None,
28
+ splits_col: str | None,
29
+ weight_col: str | None,
30
+ bounded: bool = False,
31
+ no_header_row: bool = False,
32
+ ):
33
+ df = pd.read_csv(path, header=None if no_header_row else "infer", index_col=False)
34
+
35
+ if smiles_cols is not None and rxn_cols is not None:
36
+ smiss = df[smiles_cols].T.values.tolist()
37
+ rxnss = df[rxn_cols].T.values.tolist()
38
+ input_cols = [*smiles_cols, *rxn_cols]
39
+ elif smiles_cols is not None and rxn_cols is None:
40
+ smiss = df[smiles_cols].T.values.tolist()
41
+ rxnss = None
42
+ input_cols = smiles_cols
43
+ elif smiles_cols is None and rxn_cols is not None:
44
+ smiss = None
45
+ rxnss = df[rxn_cols].T.values.tolist()
46
+ input_cols = rxn_cols
47
+ else:
48
+ smiss = df.iloc[:, [0]].T.values.tolist()
49
+ rxnss = None
50
+ input_cols = [df.columns[0]]
51
+
52
+ if target_cols is None:
53
+ target_cols = list(
54
+ column
55
+ for column in df.columns
56
+ if column
57
+ not in set( # if splits or weight is None, df.columns will never have None
58
+ input_cols + (ignore_cols or []) + [splits_col] + [weight_col]
59
+ )
60
+ )
61
+
62
+ Y = df[target_cols]
63
+ weights = None if weight_col is None else df[weight_col].to_numpy(np.single)
64
+
65
+ if bounded:
66
+ lt_mask = Y.applymap(lambda x: "<" in x).to_numpy()
67
+ gt_mask = Y.applymap(lambda x: ">" in x).to_numpy()
68
+ Y = Y.applymap(lambda x: x.strip("<").strip(">")).to_numpy(np.single)
69
+ else:
70
+ Y = Y.to_numpy(np.single)
71
+ lt_mask = None
72
+ gt_mask = None
73
+
74
+ return smiss, rxnss, Y, weights, lt_mask, gt_mask
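+ # Illustrative behavior with `bounded=True`: a target cell "<5.0" yields
+ # lt_mask=True, gt_mask=False, and a Y value of 5.0; ">0.2" yields gt_mask=True and 0.2.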
75
+
76
+
77
+ def get_column_names(
78
+ path: PathLike,
79
+ smiles_cols: Sequence[str] | None,
80
+ rxn_cols: Sequence[str] | None,
81
+ target_cols: Sequence[str] | None,
82
+ ignore_cols: Sequence[str] | None,
83
+ splits_col: str | None,
84
+ weight_col: str | None,
85
+ no_header_row: bool = False,
86
+ ) -> tuple[list[str], list[str]]:
87
+ df_cols = pd.read_csv(path, index_col=False, nrows=0).columns.tolist()
88
+
89
+ if no_header_row:
90
+ return ["SMILES"], ["pred_" + str(i) for i in range((len(df_cols) - 1))]
91
+
92
+ input_cols = (smiles_cols or []) + (rxn_cols or [])
93
+
94
+ if len(input_cols) == 0:
95
+ input_cols = [df_cols[0]]
96
+
97
+ if target_cols is None:
98
+ target_cols = list(
99
+ column
100
+ for column in df_cols
101
+ if column
102
+ not in set(
103
+ input_cols + (ignore_cols or []) + ([splits_col] or []) + ([weight_col] or [])
104
+ )
105
+ )
106
+
107
+ return input_cols, target_cols
108
+
109
+
110
+ def make_datapoints(
111
+ smiss: list[list[str]] | None,
112
+ rxnss: list[list[str]] | None,
113
+ Y: np.ndarray,
114
+ weights: np.ndarray | None,
115
+ lt_mask: np.ndarray | None,
116
+ gt_mask: np.ndarray | None,
117
+ X_d: np.ndarray | None,
118
+ V_fss: list[list[np.ndarray] | list[None]] | None,
119
+ E_fss: list[list[np.ndarray] | list[None]] | None,
120
+ V_dss: list[list[np.ndarray] | list[None]] | None,
121
+ molecule_featurizers: list[str] | None,
122
+ keep_h: bool,
123
+ add_h: bool,
124
+ ) -> tuple[list[list[MoleculeDatapoint]], list[list[ReactionDatapoint]]]:
125
+ """Make the :class:`MoleculeDatapoint`s and :class:`ReactionDatapoint`s for a given
126
+ dataset.
127
+
128
+ Parameters
129
+ ----------
130
+ smiss : list[list[str]] | None
131
+ a list of ``j`` lists of ``n`` SMILES strings, where ``j`` is the number of molecules per
132
+ datapoint and ``n`` is the number of datapoints. If ``None``, the corresponding list of
133
+ :class:`MoleculeDatapoint`\s will be empty.
134
+ rxnss : list[list[str]] | None
135
+ a list of ``k`` lists of ``n`` reaction SMILES strings, where ``k`` is the number of
136
+ reactions per datapoint. If ``None``, the corresponding list of :class:`ReactionDatapoint`\s
137
+ will be empty.
138
+ Y : np.ndarray
139
+ the target values of shape ``n x m``, where ``m`` is the number of targets
140
+ weights : np.ndarray | None
141
+ the per-datapoint weights to use in the loss function, of length ``n``. If ``None``,
142
+ the weights all default to 1.
143
+ lt_mask : np.ndarray | None
144
+ a boolean mask of shape ``n x m`` indicating which targets are less-than ('<') inequality
145
+ targets. If ``None``, ``lt_mask`` for all datapoints will be ``None``.
146
+ gt_mask : np.ndarray | None
147
+ a boolean mask of shape ``n x m`` indicating which targets are greater-than ('>') inequality
148
+ targets. If ``None``, ``gt_mask`` for all datapoints will be ``None``.
149
+ X_d : np.ndarray | None
150
+ the extra descriptors of shape ``n x p``, where ``p`` is the number of extra descriptors. If
151
+ ``None``, ``x_d`` for all datapoints will be ``None``.
152
+ V_fss : list[list[np.ndarray] | list[None]] | None
153
+ a list of ``j`` lists of ``n`` np.ndarrays each of shape ``v_jn x q_j``, where ``v_jn`` is
154
+ the number of atoms in the j-th molecule of the n-th datapoint and ``q_j`` is the number of
155
+ extra atom features used for the j-th molecules. Any of the ``j`` lists can be a list of
156
+ None values if the corresponding component does not use extra atom features. If ``None``,
157
+ ``V_f`` for all datapoints will be ``None``.
158
+ E_fss : list[list[np.ndarray] | list[None]] | None
159
+ a list of ``j`` lists of ``n`` np.ndarrays each of shape ``e_jn x r_j``, where ``e_jn`` is
160
+ the number of bonds in the j-th molecule of the n-th datapoint and ``r_j`` is the number of
161
+ extra bond features used for the j-th molecules. Any of the ``j`` lists can be a list of
162
+ None values if the corresponding component does not use extra bond features. If ``None``,
163
+ ``E_f`` for all datapoints will be ``None``.
164
+ V_dss : list[list[np.ndarray] | list[None]] | None
165
+ a list of ``j`` lists of ``n`` np.ndarrays each of shape ``v_jn x s_j``, where ``s_j`` is
166
+ the number of extra atom descriptors used for the j-th molecules. Any of the ``j`` lists can
167
+ be a list of None values if the corresponding component does not use extra atom descriptors. If
168
+ ``None``, ``V_d`` for all datapoints will be ``None``.
169
+ molecule_featurizers : list[str] | None
170
+ a list of molecule featurizer names to generate additional molecule features to use as extra
171
+ descriptors. If there are multiple molecules per datapoint, the featurizers will be applied
172
+ to each molecule and concatenated. Note that a :code:`ReactionDatapoint` has two
173
+ RDKit :class:`~rdkit.Chem.Mol` objects, reactant(s) and product(s). Each
174
+ ``molecule_featurizer`` will be applied to both of these objects.
175
+ keep_h : bool
176
+ add_h : bool
177
+
178
+ Returns
179
+ -------
180
+ list[list[MoleculeDatapoint]]
181
+ a list of ``j`` lists of ``n`` :class:`MoleculeDatapoint`\s
182
+ list[list[ReactionDatapoint]]
183
+ a list of ``k`` lists of ``n`` :class:`ReactionDatapoint`\s
184
+ .. note::
185
+ either ``j`` or ``k`` may be 0, in which case the corresponding list will be empty.
186
+
187
+ Raises
188
+ ------
189
+ ValueError
190
+ if both ``smiss`` and ``rxnss`` are ``None``.
191
+ if ``smiss`` and ``rxnss`` are both given and have different lengths.
192
+ """
193
+ if smiss is None and rxnss is None:
194
+ raise ValueError("args 'smiss' and 'rnxss' were both `None`!")
195
+ elif rxnss is None:
196
+ N = len(smiss[0])
197
+ rxnss = []
198
+ elif smiss is None:
199
+ N = len(rxnss[0])
200
+ smiss = []
201
+ elif len(smiss[0]) != len(rxnss[0]):
202
+ raise ValueError(
203
+ f"args 'smiss' and 'rxnss' must have same length! got {len(smiss[0])} and {len(rxnss[0])}"
204
+ )
205
+ else:
206
+ N = len(smiss[0])
207
+
208
+ if len(smiss) > 0:
209
+ molss = [[make_mol(smi, keep_h, add_h) for smi in smis] for smis in smiss]
210
+ if len(rxnss) > 0:
211
+ rctss = [
212
+ [
213
+ make_mol(f"{rct_smi}.{agt_smi}" if agt_smi else rct_smi, keep_h, add_h)
214
+ for rct_smi, agt_smi, _ in (rxn.split(">") for rxn in rxns)
215
+ ]
216
+ for rxns in rxnss
217
+ ]
218
+ pdtss = [
219
+ [make_mol(pdt_smi, keep_h, add_h) for _, _, pdt_smi in (rxn.split(">") for rxn in rxns)]
220
+ for rxns in rxnss
221
+ ]
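+ # Reaction SMILES take the form "reactants>agents>products"; e.g. (illustrative) for
+ # "CC(=O)O.OCC>O>CC(=O)OCC", the reactant mol is built from "CC(=O)O.OCC.O"
+ # (agents are merged into the reactants) and the product mol from "CC(=O)OCC".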
222
+
223
+ weights = np.ones(N, dtype=np.single) if weights is None else weights
224
+ gt_mask = [None] * N if gt_mask is None else gt_mask
225
+ lt_mask = [None] * N if lt_mask is None else lt_mask
226
+
227
+ n_mols = len(smiss) if smiss else 0
228
+ V_fss = [[None] * N] * n_mols if V_fss is None else V_fss
229
+ E_fss = [[None] * N] * n_mols if E_fss is None else E_fss
230
+ V_dss = [[None] * N] * n_mols if V_dss is None else V_dss
231
+
232
+     if X_d is None and molecule_featurizers is None:
+         X_d = [None] * N
+     elif molecule_featurizers is None:
+         pass
+     else:
+         molecule_featurizers = [MoleculeFeaturizerRegistry[mf]() for mf in molecule_featurizers]
+ 
+         if len(smiss) > 0:
+             # one descriptor row per datapoint, concatenated across featurizers and then across
+             # molecule components
+             mol_descriptors = np.hstack(
+                 [
+                     np.vstack([np.hstack([mf(mol) for mf in molecule_featurizers]) for mol in mols])
+                     for mols in molss
+                 ]
+             )
+             if X_d is None:
+                 X_d = mol_descriptors
+             else:
+                 X_d = np.hstack([X_d, mol_descriptors])
+ 
+         if len(rxnss) > 0:
+             # for reactions, each featurizer is applied to both the reactant and the product mol
+             rct_pdt_descriptors = np.hstack(
+                 [
+                     np.vstack(
+                         [
+                             np.hstack(
+                                 [mf(mol) for mf in molecule_featurizers for mol in (rct, pdt)]
+                             )
+                             for rct, pdt in zip(rcts, pdts)
+                         ]
+                     )
+                     for rcts, pdts in zip(rctss, pdtss)
+                 ]
+             )
+             if X_d is None:
+                 X_d = rct_pdt_descriptors
+             else:
+                 X_d = np.hstack([X_d, rct_pdt_descriptors])
+ 
+     mol_data = [
+         [
+             MoleculeDatapoint(
+                 mol=molss[mol_idx][i],
+                 name=smis[i],
+                 y=Y[i],
+                 weight=weights[i],
+                 gt_mask=gt_mask[i],
+                 lt_mask=lt_mask[i],
+                 x_d=X_d[i],
+                 x_phase=None,
+                 V_f=V_fss[mol_idx][i],
+                 E_f=E_fss[mol_idx][i],
+                 V_d=V_dss[mol_idx][i],
+             )
+             for i in range(N)
+         ]
+         for mol_idx, smis in enumerate(smiss)
+     ]
+     rxn_data = [
+         [
+             ReactionDatapoint(
+                 rct=rctss[rxn_idx][i],
+                 pdt=pdtss[rxn_idx][i],
+                 name=rxns[i],
+                 y=Y[i],
+                 weight=weights[i],
+                 gt_mask=gt_mask[i],
+                 lt_mask=lt_mask[i],
+                 x_d=X_d[i],
+                 x_phase=None,
+             )
+             for i in range(N)
+         ]
+         for rxn_idx, rxns in enumerate(rxnss)
+     ]
+ 
+     return mol_data, rxn_data
+ 
+ 
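+ # Illustrative sketch (not part of the module): calling ``make_datapoints`` directly for a
+ # single-component molecule dataset. The SMILES and targets are placeholders; the positional
+ # argument order follows the call in ``build_data_from_files`` below.
+ #
+ #     smiss = [["CCO", "c1ccccc1"]]  # one molecule component, two datapoints
+ #     Y = np.zeros((2, 1))
+ #     mol_data, rxn_data = make_datapoints(
+ #         smiss, None, Y, None, None, None, None, None, None, None,
+ #         molecule_featurizers=None, keep_h=False, add_h=False,
+ #     )
+ #     assert rxn_data == []  # no reactions were supplied
+ 
+ 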
+ def build_data_from_files(
+     p_data: PathLike,
+     no_header_row: bool,
+     smiles_cols: Sequence[str] | None,
+     rxn_cols: Sequence[str] | None,
+     target_cols: Sequence[str] | None,
+     ignore_cols: Sequence[str] | None,
+     splits_col: str | None,
+     weight_col: str | None,
+     bounded: bool,
+     p_descriptors: PathLike,
+     p_atom_feats: dict[int, PathLike],
+     p_bond_feats: dict[int, PathLike],
+     p_atom_descs: dict[int, PathLike],
+     **featurization_kwargs: Mapping,
+ ) -> list[list[MoleculeDatapoint] | list[ReactionDatapoint]]:
+     smiss, rxnss, Y, weights, lt_mask, gt_mask = parse_csv(
+         p_data,
+         smiles_cols,
+         rxn_cols,
+         target_cols,
+         ignore_cols,
+         splits_col,
+         weight_col,
+         bounded,
+         no_header_row,
+     )
+     n_molecules = len(smiss) if smiss is not None else 0
+     n_datapoints = len(Y)
+ 
+     X_ds = load_input_feats_and_descs(p_descriptors, None, None, feat_desc="X_d")
+     V_fss = load_input_feats_and_descs(p_atom_feats, n_molecules, n_datapoints, feat_desc="V_f")
+     E_fss = load_input_feats_and_descs(p_bond_feats, n_molecules, n_datapoints, feat_desc="E_f")
+     V_dss = load_input_feats_and_descs(p_atom_descs, n_molecules, n_datapoints, feat_desc="V_d")
+ 
+     mol_data, rxn_data = make_datapoints(
+         smiss,
+         rxnss,
+         Y,
+         weights,
+         lt_mask,
+         gt_mask,
+         X_ds,
+         V_fss,
+         E_fss,
+         V_dss,
+         **featurization_kwargs,
+     )
+ 
+     # molecule components are listed first, followed by reaction components
+     return mol_data + rxn_data
+ 
+ 
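+ # Illustrative sketch: loading a dataset from a CSV with one SMILES column and one target
+ # column. The column names and file paths are placeholders, not chemprop defaults, and the
+ # featurization kwargs are assumed to be forwarded to ``make_datapoints`` as above.
+ #
+ #     all_data = build_data_from_files(
+ #         "data.csv", no_header_row=False, smiles_cols=["smiles"], rxn_cols=None,
+ #         target_cols=["y"], ignore_cols=None, splits_col=None, weight_col=None, bounded=False,
+ #         p_descriptors=None, p_atom_feats={}, p_bond_feats={}, p_atom_descs={},
+ #         molecule_featurizers=None, keep_h=False, add_h=False,
+ #     )
+ 
+ 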
+ def load_input_feats_and_descs(
+     paths: dict[int, PathLike] | PathLike,
+     n_molecules: int | None,
+     n_datapoints: int | None,
+     feat_desc: str,
+ ):
+     if paths is None:
+         return None
+ 
+     match feat_desc:
+         case "X_d":
+             # extra datapoint descriptors live in a single .npz archive under the key "arr_0"
+             path = paths
+             loaded_feature = np.load(path)
+             features = loaded_feature["arr_0"]
+ 
+         case _:
+             for index in paths:
+                 if index >= n_molecules:
+                     raise ValueError(
+                         f"For {n_molecules} molecules, atom/bond features/descriptors can only be "
+                         f"specified for indices 0-{n_molecules - 1}! Got index {index}."
+                     )
+ 
+             # one .npz archive per molecule component, holding one array per datapoint under the
+             # positional keys "arr_0", "arr_1", ...
+             features = []
+             for idx in range(n_molecules):
+                 path = paths.get(idx, None)
+ 
+                 if path is not None:
+                     loaded_feature = np.load(path)
+                     loaded_feature = [
+                         loaded_feature[f"arr_{i}"] for i in range(len(loaded_feature))
+                     ]
+                 else:
+                     loaded_feature = [None] * n_datapoints
+ 
+                 features.append(loaded_feature)
+     return features
+ 
+ 
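+ # Sketch of the on-disk layout implied by the loader above (an assumption, not a spec): per-atom
+ # features for one molecule component are saved as a single .npz archive holding one array per
+ # datapoint, so ``np.savez``'s automatic ``arr_0``, ``arr_1``, ... keys line up with datapoint
+ # indices:
+ #
+ #     import numpy as np
+ #     V_fs = [np.random.rand(n_atoms, 4) for n_atoms in (3, 5)]  # 4 extra features per atom
+ #     np.savez("atom_feats_0.npz", *V_fs)
+ 
+ 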
+ def make_dataset(
+     data: Sequence[MoleculeDatapoint] | Sequence[ReactionDatapoint],
+     reaction_mode: str,
+     multi_hot_atom_featurizer_mode: Literal["V1", "V2", "ORGANIC", "RIGR"] = "V2",
+ ) -> MoleculeDataset | ReactionDataset:
+     atom_featurizer = get_multi_hot_atom_featurizer(multi_hot_atom_featurizer_mode)
+     match multi_hot_atom_featurizer_mode:
+         case "RIGR":
+             bond_featurizer = RIGRBondFeaturizer()
+         case "V1" | "V2" | "ORGANIC":
+             bond_featurizer = MultiHotBondFeaturizer()
+         case _:
+             raise TypeError(
+                 f"Unsupported atom featurizer mode '{multi_hot_atom_featurizer_mode}'!"
+             )
+ 
+     if isinstance(data[0], MoleculeDatapoint):
+         extra_atom_fdim = data[0].V_f.shape[1] if data[0].V_f is not None else 0
+         extra_bond_fdim = data[0].E_f.shape[1] if data[0].E_f is not None else 0
+         featurizer = SimpleMoleculeMolGraphFeaturizer(
+             atom_featurizer=atom_featurizer,
+             bond_featurizer=bond_featurizer,
+             extra_atom_fdim=extra_atom_fdim,
+             extra_bond_fdim=extra_bond_fdim,
+         )
+         return MoleculeDataset(data, featurizer)
+ 
+     featurizer = CondensedGraphOfReactionFeaturizer(
+         mode_=reaction_mode, atom_featurizer=atom_featurizer
+     )
+ 
+     return ReactionDataset(data, featurizer)
+ 
+ 
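+ # Illustrative sketch: wrapping a list of datapoints in a dataset. ``reaction_mode`` is only
+ # used for reaction data; "REAC_DIFF" is one of chemprop's CGR modes.
+ #
+ #     dset = make_dataset(mol_data[0], reaction_mode="REAC_DIFF")
+ 
+ 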
+ def parse_indices(idxs):
+     """Parses a string of indices into a list of integers, e.g. '0,1,2-4' -> [0, 1, 2, 3, 4]"""
+     if isinstance(idxs, str):
+         indices = []
+         for idx in idxs.split(","):
+             if "-" in idx:
+                 # a hyphenated token denotes an inclusive range
+                 start, end = map(int, idx.split("-"))
+                 indices.extend(range(start, end + 1))
+             else:
+                 indices.append(int(idx))
+         return indices
+     return idxs
chemprop-updated/chemprop/cli/utils/utils.py ADDED
@@ -0,0 +1,31 @@
+ from typing import Any
+ 
+ __all__ = ["pop_attr"]
+ 
+ 
+ def pop_attr(o: object, attr: str, *args) -> Any | None:
+     """like ``dict.pop()`` but for object attributes: return ``o.attr`` and delete it, optionally
+     falling back to a default instead of raising ``AttributeError``"""
+     match len(args):
+         case 0:
+             return _pop_attr(o, attr)
+         case 1:
+             return _pop_attr_d(o, attr, args[0])
+         case _:
+             raise TypeError(f"Expected at most 2 arguments! got: {len(args)}")
+ 
+ 
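+ # Illustrative usage on a hypothetical namespace object:
+ #
+ #     ns = argparse.Namespace(foo=1)
+ #     pop_attr(ns, "foo")        # -> 1, and ``ns`` no longer has "foo"
+ #     pop_attr(ns, "foo", None)  # -> None instead of raising AttributeError
+ 
+ 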
+ def _pop_attr(o: object, attr: str) -> Any:
+     val = getattr(o, attr)
+     delattr(o, attr)
+ 
+     return val
+ 
+ 
+ def _pop_attr_d(o: object, attr: str, default: Any | None = None) -> Any | None:
+     try:
+         val = getattr(o, attr)
+         delattr(o, attr)
+     except AttributeError:
+         val = default
+ 
+     return val
chemprop-updated/chemprop/conf.py ADDED
@@ -0,0 +1,6 @@
+ """Global configuration variables for chemprop"""
2
+
3
+ from chemprop.featurizers.molgraph.molecule import SimpleMoleculeMolGraphFeaturizer
4
+
5
+ DEFAULT_ATOM_FDIM, DEFAULT_BOND_FDIM = SimpleMoleculeMolGraphFeaturizer().shape
6
+ DEFAULT_HIDDEN_DIM = 300
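+ # note: ``shape`` on the featurizer is taken here to be an (atom_fdim, bond_fdim) pair, so these
+ # defaults track whatever the default featurization scheme produces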