Yuan (Cyrus) Chiang Elizabeth Weaver committed on
Commit
35c7a66
·
unverified ·
1 Parent(s): 9b93949

Add force equivariance benchmark (#65)

Browse files

* add equivariance testing

* mv to benchmarks folder; ruff

* deterministic discovery of pytest

---------

Co-authored-by: Elizabeth Weaver <e.j.weaver64@gmail.com>

Files changed (2) hide show
  1. benchmarks/force_equivariance/run.py +307 -0
  2. pytest.ini +3 -0
benchmarks/force_equivariance/run.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Define equivariance testing task.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from collections.abc import Sequence
8
+ from pathlib import Path
9
+
10
+ import numpy as np
11
+ from ase import Atoms
12
+ from prefect import task
13
+ from scipy.spatial.transform import Rotation as R
14
+ from tqdm import tqdm
15
+
16
+
17
def generate_random_unit_vector():
    """Return a random 3-component vector scaled to unit length.

    Three independent standard-normal samples are drawn from NumPy's global
    random state (so ``np.random.seed`` controls the result) and the sample
    is divided by its Euclidean norm.
    """
    direction = np.random.normal(0, 1, 3)
    length = np.linalg.norm(direction)
    return direction / length
21
+
22
+
23
def rotate_molecule_arbitrary(
    atoms: Atoms, angle: float, axis: np.ndarray
) -> tuple[Atoms, np.ndarray]:
    """
    Rotate a structure (positions and cell) about an arbitrary axis.

    Args:
        atoms: Structure to rotate; it is copied and left unmodified
        angle: Rotation angle in degrees
        axis: Rotation axis (3-vector). It is normalised internally so that
            its length does not affect the rotation angle. A zero vector
            yields the identity rotation.

    Returns:
        tuple containing:
            - rotated_atoms: Copy of ``atoms`` with rotated positions and cell
            - rotation_mat: 3 x 3 rotation matrix that was applied
    """
    axis = np.asarray(axis, dtype=float)
    axis_norm = np.linalg.norm(axis)
    # from_rotvec encodes the angle in the vector's length, so a non-unit
    # axis would silently scale the requested angle.
    if axis_norm > 0:
        axis = axis / axis_norm

    rot = R.from_rotvec(np.radians(angle) * axis)

    rotated_atoms = atoms.copy()
    rotated_atoms.set_positions(rot.apply(rotated_atoms.get_positions()))
    # Rotate the cell vectors together with the positions.
    rotated_atoms.set_cell(rot.apply(atoms.get_cell()))
    return rotated_atoms, rot.as_matrix()
37
+
38
+
39
def compare_forces(
    original_forces: np.ndarray,
    rotated_forces: np.ndarray,
    rotation_mat: np.ndarray,
    zero_threshold: float = 1e-10,
) -> tuple[float, np.ndarray]:
    """
    Compare forces before and after rotation, with handling of 0 force case.

    The original forces are rotated with ``rotation_mat`` and compared with
    the forces computed on the rotated structure.

    Args:
        original_forces: Forces before rotation (N x 3 array)
        rotated_forces: Forces after rotation (N x 3 array)
        rotation_mat: 3 x 3 rotation matrix
        zero_threshold: Threshold below which forces are considered zero

    Returns:
        tuple containing:
            - mae: Mean absolute error over all force components
            - cosine_similarity: Per-atom cosine similarity between the
              rotated original force and the force on the rotated structure
              (array of length N)
    """
    original_forces = np.asarray(original_forces, dtype=float)
    rotated_forces = np.asarray(rotated_forces, dtype=float)

    # Row-vector convention: f @ R.T applies R to every force vector.
    rotated_original_forces = np.dot(original_forces, rotation_mat.T)
    mae = np.mean(np.abs(rotated_original_forces - rotated_forces))

    original_magnitudes = np.linalg.norm(rotated_original_forces, axis=1)
    rotated_magnitudes = np.linalg.norm(rotated_forces, axis=1)

    zero_original = original_magnitudes < zero_threshold
    zero_rotated = rotated_magnitudes < zero_threshold
    both_zero = zero_original & zero_rotated
    either_zero = zero_original | zero_rotated
    one_zero = either_zero & ~both_zero
    valid_forces = ~either_zero

    cosine_similarity = np.zeros(len(original_forces))
    if np.any(valid_forces):
        dot_products = np.sum(
            rotated_original_forces[valid_forces] * rotated_forces[valid_forces],
            axis=1,
        )
        # Reuse the magnitudes computed above rather than recomputing norms.
        norms_product = (
            original_magnitudes[valid_forces] * rotated_magnitudes[valid_forces]
        )
        cosine_similarity[valid_forces] = dot_products / norms_product

    # If both forces are 0, cosine similarity should be 1.
    # If only one is 0, we take the conservative -1.
    cosine_similarity[both_zero] = 1.0
    cosine_similarity[one_zero] = -1.0

    return mae, cosine_similarity
89
+
90
+
91
def save_molecule_results(
    aggregate_results: dict, idx_list: np.ndarray, save_path: str | Path
) -> None:
    """
    Save all molecule results from equivariance testing to .npy files.
    Save the index list of the atoms for further analysis.

    Three files are written next to ``save_path``, named after its stem:
    ``<stem>_maes.npy`` and ``<stem>_cosine_similarities.npy`` (both of shape
    molecules x angles x axes; cosine similarities are averaged over the last
    axis of each per-axis entry) and ``<stem>_idx_list.npy``.

    Args:
        aggregate_results: Dictionary containing the aggregated results from run()
        idx_list: List of the indices of the atoms in the original dataset
        save_path: Path to save the .npy files
    """
    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)

    all_molecule_results = aggregate_results["molecule_results"]
    # The first molecule defines the angle keys and the number of axes.
    first_by_angle = all_molecule_results[0]["results_by_angle"]
    rotation_angles = list(first_by_angle)
    num_random_axes = len(first_by_angle[rotation_angles[0]]["maes"])

    shape = (len(all_molecule_results), len(rotation_angles), num_random_axes)
    maes = np.zeros(shape)
    cosine_similarities = np.zeros(shape)

    for mol_idx, molecule in enumerate(all_molecule_results):
        for angle_idx, angle in enumerate(rotation_angles):
            angle_results = molecule["results_by_angle"][angle]
            maes[mol_idx, angle_idx, :] = angle_results["maes"]
            cosine_similarities[mol_idx, angle_idx, :] = np.mean(
                angle_results["cosine_similarities"], axis=-1
            )

    stem = save_path.stem
    np.save(save_path.with_name(f"{stem}_maes.npy"), maes)
    np.save(
        save_path.with_name(f"{stem}_cosine_similarities.npy"),
        cosine_similarities,
    )
    np.save(save_path.with_name(f"{stem}_idx_list.npy"), idx_list)
137
+
138
+
139
# NOTE(review): `_generate_task_run_name`, `TASK_SOURCE`, `INPUTS` and
# `BaseCalculator` are not defined or imported in the visible part of this
# file -- confirm they are provided, otherwise the decorator below raises
# NameError when the module is imported.
@task(
    name="Equivariance testing",
    task_run_name=_generate_task_run_name,
    cache_policy=TASK_SOURCE + INPUTS,
)
def run(
    atoms_list: Sequence[Atoms],
    idx_list: np.ndarray,
    calculator: BaseCalculator,
    save_path: str | Path | None = None,
    rotation_angles: list[float] | np.ndarray | None = None,
    num_random_axes: int = 100,
    threshold: float = 1e-3,
    seed: int | None = None,
) -> dict:
    """
    Test equivariance of force predictions under rotations for multiple structures.

    Every structure is rotated by each angle about each of ``num_random_axes``
    random axes (the same axes are shared by all structures and angles). The
    forces on the rotated structure are compared with the rotated forces of
    the unrotated structure.

    Args:
        atoms_list: List of input atomic structures
        idx_list: List of the indices of the atoms in the original dataset
        calculator: Calculator to use
        save_path: Optional path; when given, results are written to .npy
            files named after its stem (str or Path)
        rotation_angles: Rotation angles in degrees; defaults to
            30, 60, ..., 360
        num_random_axes: Number of random rotation axes to test per angle
        threshold: Threshold for considering forces equivariant; a rotation
            passes when mae < threshold and every per-atom cosine similarity
            is > 1 - threshold
        seed: Random seed passed to ``np.random.seed`` (global state)

    Returns:
        Dictionary containing test results: aggregate statistics plus the
        per-structure results under "molecule_results"
    """
    if seed is not None:
        np.random.seed(seed)

    if rotation_angles is None:
        rotation_angles = np.arange(30, 361, 30)
    rotation_angles = np.array(rotation_angles)

    all_results = []

    cross_molecule_cosine_sims = {angle: [] for angle in rotation_angles}
    cross_molecule_mae = {angle: [] for angle in rotation_angles}

    rotation_axes = [generate_random_unit_vector() for _ in range(num_random_axes)]

    total_tests = len(atoms_list) * len(rotation_angles) * num_random_axes

    # The context manager closes the progress bar even if a calculation raises.
    with tqdm(total=total_tests, desc="Testing rotations") as pbar:
        for atom_idx, atoms in enumerate(atoms_list):
            atoms = atoms.copy()
            atoms.calc = calculator
            original_forces = atoms.get_forces()

            results_by_angle = {
                angle: {
                    "mae": [],
                    "cosine_similarities": [],
                    "passed_tests": 0,
                    "passed_mae": 0,
                    "passed_cosine_similarity": 0,
                }
                for angle in rotation_angles
            }

            # Test each angle with multiple random axes
            for angle in rotation_angles:
                angle_results = results_by_angle[angle]
                for axis in rotation_axes:
                    rotated_atoms, rotation_mat = rotate_molecule_arbitrary(
                        atoms, angle, axis
                    )
                    rotated_atoms.calc = calculator
                    rotated_forces = rotated_atoms.get_forces()
                    mae, cosine_similarity = compare_forces(
                        original_forces, rotated_forces, rotation_mat
                    )
                    angle_results["mae"].append(mae)
                    angle_results["cosine_similarities"].append(cosine_similarity)

                    cross_molecule_cosine_sims[angle].append(
                        float(np.mean(cosine_similarity))
                    )
                    cross_molecule_mae[angle].append(float(np.mean(mae)))

                    mae_check = mae < threshold
                    cosine_check = all(cosine_similarity > (1 - threshold))
                    angle_results["passed_tests"] += int(mae_check and cosine_check)
                    angle_results["passed_mae"] += int(mae_check)
                    angle_results["passed_cosine_similarity"] += int(cosine_check)

                    pbar.update(1)

            # Compute summary statistics
            for angle in rotation_angles:
                results = results_by_angle[angle]
                results["mean_cosine_similarity"] = float(
                    np.mean(results["cosine_similarities"])
                )
                results["avg_mae"] = float(np.mean(results["mae"]))
                results["equivariant_ratio"] = (
                    results["passed_tests"] / num_random_axes
                )
                results["mae_passed_ratio"] = results["passed_mae"] / num_random_axes
                results["cosine_passed_ratio"] = (
                    results["passed_cosine_similarity"] / num_random_axes
                )
                results["passed"] = results["passed_tests"] == num_random_axes
                # From here on the two counters below are replaced by booleans
                # ("did every axis pass?"); the ratios above keep the counts.
                results["passed_mae"] = results["passed_mae"] == num_random_axes
                results["passed_cosine_similarity"] = (
                    results["passed_cosine_similarity"] == num_random_axes
                )
                # Convert NumPy values to plain Python floats / nested lists.
                results["maes"] = [float(x) for x in results["mae"]]
                results["cosine_similarities"] = [
                    [float(y) for y in x] for x in results["cosine_similarities"]
                ]

            molecule_results = {
                "mol_idx": idx_list[atom_idx],
                "results_by_angle": results_by_angle,
                "all_passed": all(
                    results_by_angle[angle]["passed"] for angle in rotation_angles
                ),
                "avg_cosine_similarity_by_molecule": float(
                    np.mean(
                        [
                            results_by_angle[angle]["mean_cosine_similarity"]
                            for angle in rotation_angles
                        ]
                    )
                ),
                "avg_mae_by_molecule": float(
                    np.mean(
                        [
                            results_by_angle[angle]["avg_mae"]
                            for angle in rotation_angles
                        ]
                    )
                ),
                "overall_equivariant_ratio": float(
                    np.mean(
                        [
                            results_by_angle[angle]["equivariant_ratio"]
                            for angle in rotation_angles
                        ]
                    )
                ),
            }

            all_results.append(molecule_results)

    aggregate_results = {
        "num_molecules": len(atoms_list),
        "all_molecules_passed": all(result["all_passed"] for result in all_results),
        "average_equivariant_ratio": float(
            np.mean([result["overall_equivariant_ratio"] for result in all_results])
        ),
        "average_cosine_similarity_by_angle": {
            angle: float(np.mean(sims))
            for angle, sims in cross_molecule_cosine_sims.items()
        },
        "average_mae_by_angle": {
            angle: float(np.mean(diffs)) for angle, diffs in cross_molecule_mae.items()
        },
        "molecule_results": all_results,
    }

    if save_path:
        # Accept plain strings as well as Path objects, as the signature allows.
        save_path = Path(save_path)
        save_molecule_results(aggregate_results, idx_list, save_path)
        np.save(
            str(save_path.with_name(f"{save_path.stem}_molecule_results.npy")),
            all_results,
        )

    return aggregate_results
pytest.ini ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [pytest]
2
+ testpaths = tests
3
+ python_files = test_*.py