# Yuan (Cyrus) Chiang
# Add force equivariance benchmark (#65)
# 35c7a66 unverified
"""
Define equivariance testing task.
"""
from __future__ import annotations
from collections.abc import Sequence
from pathlib import Path
import numpy as np
from ase import Atoms
from prefect import task
from scipy.spatial.transform import Rotation as R
from tqdm import tqdm
def generate_random_unit_vector():
    """Return a randomly oriented 3-component vector of unit length.

    Three independent standard-normal components are drawn from NumPy's
    global random state and the resulting vector is divided by its
    Euclidean norm.
    """
    sample = np.random.normal(loc=0, scale=1, size=3)
    length = np.linalg.norm(sample)
    return sample / length
def rotate_molecule_arbitrary(
    atoms: Atoms, angle: float, axis: np.ndarray
) -> tuple[Atoms, np.ndarray]:
    """Return a rotated copy of ``atoms`` together with the rotation matrix.

    Args:
        atoms: Structure to rotate; it is copied, not modified.
        angle: Rotation angle in degrees (converted with ``np.radians``).
        axis: Rotation axis. The rotation vector handed to SciPy is
            ``radians(angle) * axis``, so a non-unit ``axis`` scales the
            effective rotation angle.

    Returns:
        ``(rotated_atoms, rotation_mat)`` where ``rotated_atoms`` has both its
        positions and its cell vectors rotated, and ``rotation_mat`` is the
        3 x 3 matrix of the applied rotation.
    """
    rotation = R.from_rotvec(np.radians(angle) * axis)

    result = atoms.copy()
    # Positions and cell vectors are transformed by the same rotation.
    result.set_positions(rotation.apply(result.get_positions()))
    result.set_cell(rotation.apply(atoms.get_cell()))

    return result, rotation.as_matrix()
def compare_forces(
    original_forces: np.ndarray,
    rotated_forces: np.ndarray,
    rotation_mat: np.ndarray,
    zero_threshold: float = 1e-10,
) -> tuple[float, np.ndarray]:
    """
    Compare forces before and after rotation, with handling of 0 force case.

    The original forces are first rotated with ``rotation_mat`` and then
    compared against the forces obtained on the rotated structure.

    Args:
        original_forces: Forces before rotation (N x 3 array)
        rotated_forces: Forces after rotation (N x 3 array)
        rotation_mat: 3 x 3 rotation matrix
        zero_threshold: Threshold on the force norm below which a force is
            considered zero

    Returns:
        tuple containing:
            - mae: Mean absolute error over all force components between the
              rotated original forces and ``rotated_forces``
            - cosine_similarity: Length-N array with the per-row cosine
              similarity. Rows where both forces are zero get 1.0, rows where
              exactly one force is zero get -1.0, and all other values are
              clipped to [-1, 1].
    """
    original_forces = np.asarray(original_forces)
    rotated_forces = np.asarray(rotated_forces)

    # Row-vector convention: (R @ f)^T == f^T @ R^T.
    expected_forces = np.dot(original_forces, np.asarray(rotation_mat).T)

    mae = np.mean(np.abs(expected_forces - rotated_forces))

    expected_norms = np.linalg.norm(expected_forces, axis=1)
    rotated_norms = np.linalg.norm(rotated_forces, axis=1)

    zero_expected = expected_norms < zero_threshold
    zero_rotated = rotated_norms < zero_threshold
    both_zero = zero_expected & zero_rotated
    one_zero = zero_expected ^ zero_rotated
    valid_forces = ~(zero_expected | zero_rotated)

    cosine_similarity = np.zeros(len(original_forces))
    if np.any(valid_forces):
        dot_products = np.sum(
            expected_forces[valid_forces] * rotated_forces[valid_forces], axis=1
        )
        norms_product = expected_norms[valid_forces] * rotated_norms[valid_forces]
        # Rounding can push the ratio marginally outside [-1, 1]; clip it so
        # the result is always a valid cosine.
        cosine_similarity[valid_forces] = np.clip(
            dot_products / norms_product, -1.0, 1.0
        )

    # If both forces are 0, cosine similarity should be 1.
    # If only one is 0, we take the conservative -1.
    cosine_similarity[both_zero] = 1.0
    cosine_similarity[one_zero] = -1.0

    return mae, cosine_similarity
def save_molecule_results(
    aggregate_results: dict, idx_list: np.ndarray, save_path: str | Path
) -> None:
    """
    Save all molecule results from equivariance testing to .npy files.
    Save the index list of the atoms for further analysis.

    Three files are written next to ``save_path``, named after its stem:
    ``<stem>_maes.npy``, ``<stem>_cosine_similarities.npy`` and
    ``<stem>_idx_list.npy``. The first two hold arrays of shape
    (num_molecules, num_angles, num_random_axes); the cosine similarities are
    averaged over the last axis of each stored entry (one value per rotation).

    The array layout (angles, number of axes) is read from the first molecule,
    so every molecule is expected to carry the same angles and axis count.
    With no molecules, no angles or no axes, correspondingly empty arrays are
    written.

    Args:
        aggregate_results: Dictionary containing the aggregated results from run()
        idx_list: List of the indices of the atoms in the original dataset
        save_path: Path to save the .npy files
    """
    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)

    all_molecule_results = aggregate_results["molecule_results"]

    rotation_angles = []
    num_random_axes = 0
    if all_molecule_results:
        first_by_angle = all_molecule_results[0]["results_by_angle"]
        rotation_angles = list(first_by_angle)
        if rotation_angles:
            num_random_axes = len(first_by_angle[rotation_angles[0]]["maes"])

    shape = (len(all_molecule_results), len(rotation_angles), num_random_axes)
    maes = np.zeros(shape)
    cosine_similarities = np.zeros(shape)

    for mol_idx, molecule in enumerate(all_molecule_results):
        for angle_idx, angle in enumerate(rotation_angles):
            angle_results = molecule["results_by_angle"][angle]
            maes[mol_idx, angle_idx, :] = angle_results["maes"]
            per_rotation = np.asarray(angle_results["cosine_similarities"])
            # Nothing to average (and nothing to fill) when no axes were tested.
            if len(per_rotation):
                cosine_similarities[mol_idx, angle_idx, :] = np.mean(
                    per_rotation, axis=-1
                )

    np.save(save_path.with_name(f"{save_path.stem}_maes.npy"), maes)
    np.save(
        save_path.with_name(f"{save_path.stem}_cosine_similarities.npy"),
        cosine_similarities,
    )
    np.save(save_path.with_name(f"{save_path.stem}_idx_list.npy"), idx_list)
# NOTE(review): `_generate_task_run_name`, `TASK_SOURCE`, `INPUTS` and
# `BaseCalculator` are not imported in the visible part of this file —
# confirm they are brought into scope before this decorator is evaluated.
@task(
    name="Equivariance testing",
    task_run_name=_generate_task_run_name,
    cache_policy=TASK_SOURCE + INPUTS,
)
def run(
    atoms_list: Sequence[Atoms],
    idx_list: np.ndarray,
    calculator: BaseCalculator,
    save_path: str | Path | None = None,
    rotation_angles: list[float] | np.ndarray | None = None,
    num_random_axes: int = 100,
    threshold: float = 1e-3,
    seed: int | None = None,
) -> dict:
    """
    Test equivariance of force predictions under rotations for multiple structures.

    Every structure is rotated by each angle around each of ``num_random_axes``
    random axes (the same set of axes is reused for all structures and
    angles). The forces predicted on the rotated structure are compared with
    the rotated forces of the unrotated structure.

    Args:
        atoms_list: List of input atomic structures
        idx_list: List of the indices of the atoms in the original dataset;
            ``idx_list[i]`` is stored as ``mol_idx`` for ``atoms_list[i]``
        calculator: Calculator to use
        save_path: If given, per-molecule arrays are written next to this path
            via ``save_molecule_results`` and the full per-molecule results
            are written to ``<stem>_molecule_results.npy``
        rotation_angles: Rotation angles in degrees; defaults to
            30, 60, ..., 360
        num_random_axes: Number of random rotation axes tested per angle
            (must be at least 1)
        threshold: A rotation passes when the force MAE is below ``threshold``
            and every per-atom cosine similarity is above ``1 - threshold``
        seed: Random seed; when given it seeds NumPy's global random state

    Returns:
        Dictionary containing test results with the keys ``num_molecules``,
        ``all_molecules_passed``, ``average_equivariant_ratio``,
        ``average_cosine_similarity_by_angle``, ``average_mae_by_angle`` and
        ``molecule_results`` (one entry per structure).

    Raises:
        ValueError: If ``num_random_axes`` is smaller than 1.
    """
    # The pass ratios below divide by num_random_axes; fail early and clearly
    # instead of with a ZeroDivisionError after the forces were computed.
    if num_random_axes < 1:
        raise ValueError("num_random_axes must be at least 1")

    if seed is not None:
        np.random.seed(seed)

    if rotation_angles is None:
        rotation_angles = np.arange(30, 361, 30)
    rotation_angles = np.array(rotation_angles)

    all_results = []
    cross_molecule_cosine_sims = {angle: [] for angle in rotation_angles}
    cross_molecule_mae = {angle: [] for angle in rotation_angles}

    rotation_axes = [generate_random_unit_vector() for _ in range(num_random_axes)]

    total_tests = len(atoms_list) * len(rotation_angles) * num_random_axes

    # The context manager closes the progress bar even if a calculation fails.
    with tqdm(total=total_tests, desc="Testing rotations") as pbar:
        for atom_idx, atoms in enumerate(atoms_list):
            atoms = atoms.copy()
            atoms.calc = calculator
            original_forces = atoms.get_forces()

            results_by_angle = {
                angle: {
                    "mae": [],
                    "cosine_similarities": [],
                    "passed_tests": 0,
                    "passed_mae": 0,
                    "passed_cosine_similarity": 0,
                }
                for angle in rotation_angles
            }

            # Test each angle with multiple random axes
            for angle in rotation_angles:
                for axis in rotation_axes:
                    rotated_atoms, rotation_mat = rotate_molecule_arbitrary(
                        atoms, angle, axis
                    )
                    rotated_atoms.calc = calculator
                    rotated_forces = rotated_atoms.get_forces()

                    mae, cosine_similarity = compare_forces(
                        original_forces, rotated_forces, rotation_mat
                    )

                    results_by_angle[angle]["mae"].append(mae)
                    results_by_angle[angle]["cosine_similarities"].append(
                        cosine_similarity
                    )
                    cross_molecule_cosine_sims[angle].append(
                        float(np.mean(cosine_similarity))
                    )
                    cross_molecule_mae[angle].append(float(np.mean(mae)))

                    mae_check = mae < threshold
                    cosine_check = bool(np.all(cosine_similarity > (1 - threshold)))
                    results_by_angle[angle]["passed_tests"] += int(
                        mae_check and cosine_check
                    )
                    results_by_angle[angle]["passed_mae"] += int(mae_check)
                    results_by_angle[angle]["passed_cosine_similarity"] += int(
                        cosine_check
                    )
                    pbar.update(1)

            # Compute summary statistics
            for angle in rotation_angles:
                results = results_by_angle[angle]
                results["mean_cosine_similarity"] = float(
                    np.mean(results["cosine_similarities"])
                )
                results["avg_mae"] = float(np.mean(results["mae"]))
                results["equivariant_ratio"] = results["passed_tests"] / num_random_axes
                results["mae_passed_ratio"] = results["passed_mae"] / num_random_axes
                results["cosine_passed_ratio"] = (
                    results["passed_cosine_similarity"] / num_random_axes
                )
                results["passed"] = results["passed_tests"] == num_random_axes
                # From here on the two counters are replaced by booleans
                # ("did every rotation pass?"); the ratios above keep the counts.
                results["passed_mae"] = results["passed_mae"] == num_random_axes
                results["passed_cosine_similarity"] = (
                    results["passed_cosine_similarity"] == num_random_axes
                )
                # Plain Python floats / nested lists for the stored results.
                results["maes"] = [float(x) for x in results["mae"]]
                results["cosine_similarities"] = [
                    [float(y) for y in x] for x in results["cosine_similarities"]
                ]

            molecule_results = {
                "mol_idx": idx_list[atom_idx],
                "results_by_angle": results_by_angle,
                "all_passed": all(
                    results_by_angle[angle]["passed"] for angle in rotation_angles
                ),
                "avg_cosine_similarity_by_molecule": float(
                    np.mean(
                        [
                            results_by_angle[angle]["mean_cosine_similarity"]
                            for angle in rotation_angles
                        ]
                    )
                ),
                "avg_mae_by_molecule": float(
                    np.mean(
                        [results_by_angle[angle]["avg_mae"] for angle in rotation_angles]
                    )
                ),
                "overall_equivariant_ratio": float(
                    np.mean(
                        [
                            results_by_angle[angle]["equivariant_ratio"]
                            for angle in rotation_angles
                        ]
                    )
                ),
            }
            all_results.append(molecule_results)

    aggregate_results = {
        "num_molecules": len(atoms_list),
        "all_molecules_passed": all(result["all_passed"] for result in all_results),
        "average_equivariant_ratio": float(
            np.mean([result["overall_equivariant_ratio"] for result in all_results])
        ),
        "average_cosine_similarity_by_angle": {
            angle: float(np.mean(sims))
            for angle, sims in cross_molecule_cosine_sims.items()
        },
        "average_mae_by_angle": {
            angle: float(np.mean(diffs)) for angle, diffs in cross_molecule_mae.items()
        },
        "molecule_results": all_results,
    }

    if save_path:
        # save_path may be a plain string; normalise it before using the
        # Path-only attributes (.with_name / .stem) below.
        save_path = Path(save_path)
        save_molecule_results(aggregate_results, idx_list, save_path)
        # all_results is a list of dicts, so NumPy stores it as a pickled
        # object array (loading it back requires allow_pickle=True).
        np.save(
            str(save_path.with_name(f"{save_path.stem}_molecule_results.npy")),
            all_results,
        )

    return aggregate_results