In [1]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
from ase import Atom, Atoms
from ase.data import chemical_symbols, covalent_radii, vdw_alvarez
from ase.io import read, write
from pymatgen.core import Element
from scipy import stats
from tqdm.auto import tqdm

from mlip_arena.models.utils import REGISTRY, MLIPEnum

model_name = "EquiformerV2(OC20)"

calc = MLIPEnum[model_name].value()



In [2]:
for symbol in tqdm(chemical_symbols[1:]):

    s = set([symbol])

    if "X" in s:
        continue

    try:
        atom = Atom(symbol)
        rmin = min(covalent_radii[atom.number] * 0.5, 1.0)
        rvdw = vdw_alvarez.vdw_radii[atom.number] if atom.number < len(vdw_alvarez.vdw_radii) else np.nan
        rmax = 3.1 * rvdw if not np.isnan(rvdw) else 6
        rstep = 0.01 #if rmin < 1 else 0.4

        a = 2 * rmax

        npts = int((rmax - rmin)/rstep)

        rs = np.linspace(rmin, rmax, npts)
        e = np.zeros_like(rs)

        da = symbol + symbol

        out_dir = Path(REGISTRY[model_name]["family"]) / str(da)

        os.makedirs(out_dir, exist_ok=True)

        skip = 0

        element = Element(symbol)

        try:
            m = element.valence[1]
            if element.valence == (0, 2):
                m = 0
        except:
            m = 0


        r = rs[0]

        positions = [
            [a/2-r/2, a/2, a/2],
            [a/2+r/2, a/2, a/2],
        ]

        traj_fpath = out_dir / f"{model_name}.extxyz"

        if traj_fpath.exists():
            traj = read(traj_fpath, index=":")
            skip = len(traj)
            atoms = traj[-1]
        else:
            # Create the unit cell with two atoms
            atoms = Atoms(
                da,
                positions=positions,
                # magmoms=magmoms,
                cell=[a, a+0.001, a+0.002],
                pbc=True
            )

        print(atoms)

        calc = calc

        atoms.calc = calc

        for i, r in enumerate(tqdm(rs)):

            if i < skip:
                continue

            positions = [
                [a/2-r/2, a/2, a/2],
                [a/2+r/2, a/2, a/2],
            ]

            # atoms.set_initial_magnetic_moments(magmoms)

            atoms.set_positions(positions)

            e[i] = atoms.get_potential_energy()

            write(traj_fpath, atoms, append="a")
    except Exception as e:
        print(e)


  0%|          | 0/118 [00:00<?, ?it/s]

Atoms(symbols='H2', pbc=True, cell=[7.4399999999999995, 7.441, 7.441999999999999], calculator=SinglePointCalculator(...))


  0%|          | 0/356 [00:00<?, ?it/s]

Atoms(symbols='He2', pbc=True, cell=[8.866, 8.866999999999999, 8.868], calculator=SinglePointCalculator(...))


  0%|          | 0/429 [00:00<?, ?it/s]

Atoms(symbols='Li2', pbc=True, cell=[13.144000000000002, 13.145000000000001, 13.146000000000003], calculator=SinglePointCalculator(...))


  0%|          | 0/593 [00:00<?, ?it/s]

Atoms(symbols='Be2', pbc=True, cell=[12.276, 12.277, 12.278], calculator=SinglePointCalculator(...))


  0%|          | 0/565 [00:00<?, ?it/s]

Atoms(symbols='B2', pbc=True, cell=[11.842, 11.843, 11.844000000000001], calculator=SinglePointCalculator(...))


  0%|          | 0/550 [00:00<?, ?it/s]

Atoms(symbols='C2', pbc=True, cell=[10.974, 10.975, 10.976], calculator=SinglePointCalculator(...))


  0%|          | 0/510 [00:00<?, ?it/s]

Atoms(symbols='N2', pbc=True, cell=[10.292, 10.293, 10.294], calculator=SinglePointCalculator(...))


  0%|          | 0/479 [00:00<?, ?it/s]

Atoms(symbols='O2', pbc=True, cell=[9.3, 9.301, 9.302000000000001], calculator=SinglePointCalculator(...))


  0%|          | 0/432 [00:00<?, ?it/s]

Atoms(symbols='F2', pbc=True, cell=[9.052, 9.052999999999999, 9.054], calculator=SinglePointCalculator(...))


  0%|          | 0/424 [00:00<?, ?it/s]

Atoms(symbols='Ne2', pbc=True, cell=[9.796000000000001, 9.797, 9.798000000000002], calculator=SinglePointCalculator(...))


  0%|          | 0/460 [00:00<?, ?it/s]

Atoms(symbols='Na2', pbc=True, cell=[15.5, 15.501, 15.502], calculator=SinglePointCalculator(...))


  0%|          | 0/692 [00:00<?, ?it/s]

Atoms(symbols='Mg2', pbc=True, cell=[15.562, 15.562999999999999, 15.564], calculator=SinglePointCalculator(...))


  0%|          | 0/707 [00:00<?, ?it/s]

Atoms(symbols='Al2', pbc=True, cell=[13.950000000000001, 13.951, 13.952000000000002], calculator=SinglePointCalculator(...))


  0%|          | 0/637 [00:00<?, ?it/s]

Atoms(symbols='Si2', pbc=True, cell=[13.578, 13.578999999999999, 13.58], calculator=SinglePointCalculator(...))


  0%|          | 0/623 [00:00<?, ?it/s]

Atoms(symbols='P2', pbc=True, cell=[11.78, 11.780999999999999, 11.782], calculator=SinglePointCalculator(...))


  0%|          | 0/535 [00:00<?, ?it/s]

Atoms(symbols='S2', pbc=True, cell=[11.718, 11.719, 11.72], calculator=SinglePointCalculator(...))


  0%|          | 0/533 [00:00<?, ?it/s]

Atoms(symbols='Cl2', pbc=True, cell=[11.284, 11.285, 11.286000000000001], calculator=SinglePointCalculator(...))


  0%|          | 0/513 [00:00<?, ?it/s]

Atoms(symbols='Ar2', pbc=True, cell=[11.346, 11.347, 11.348], calculator=SinglePointCalculator(...))


  0%|          | 0/514 [00:00<?, ?it/s]

Atoms(symbols='K2', pbc=True, cell=[16.926000000000002, 16.927000000000003, 16.928], calculator=SinglePointCalculator(...))


  0%|          | 0/746 [00:00<?, ?it/s]

Atoms(symbols='Ca2', pbc=True, cell=[16.244, 16.245, 16.246], calculator=SinglePointCalculator(...))


  0%|          | 0/724 [00:00<?, ?it/s]

Atoms(symbols='Sc2', pbc=True, cell=[15.996, 15.997, 15.998000000000001], calculator=SinglePointCalculator(...))


  0%|          | 0/714 [00:00<?, ?it/s]

Atoms(symbols='Ti2', pbc=True, cell=[15.252, 15.253, 15.254000000000001], calculator=SinglePointCalculator(...))


  0%|          | 0/682 [00:00<?, ?it/s]

Atoms(symbols='V2', pbc=True, cell=[15.004, 15.004999999999999, 15.006], calculator=SinglePointCalculator(...))


  0%|          | 0/673 [00:00<?, ?it/s]

Atoms(symbols='Cr2', pbc=True, cell=[15.190000000000001, 15.191, 15.192000000000002], calculator=SinglePointCalculator(...))


  0%|          | 0/690 [00:00<?, ?it/s]

Atoms(symbols='Mn2', pbc=True, cell=[15.190000000000001, 15.191, 15.192000000000002], calculator=SinglePointCalculator(...))


  0%|          | 0/690 [00:00<?, ?it/s]

Atoms(symbols='Fe2', pbc=True, cell=[15.128, 15.129, 15.13], calculator=SinglePointCalculator(...))


  0%|          | 0/690 [00:00<?, ?it/s]

Atoms(symbols='Co2', pbc=True, cell=[14.879999999999999, 14.880999999999998, 14.882], calculator=SinglePointCalculator(...))


  0%|          | 0/681 [00:00<?, ?it/s]

Atoms(symbols='Ni2', pbc=True, cell=[14.879999999999999, 14.880999999999998, 14.882], calculator=SinglePointCalculator(...))


  0%|          | 0/681 [00:00<?, ?it/s]

Atoms(symbols='Cu2', pbc=True, cell=[14.756, 14.757, 14.758000000000001], calculator=SinglePointCalculator(...))


  0%|          | 0/671 [00:00<?, ?it/s]

Atoms(symbols='Zn2', pbc=True, cell=[14.818000000000001, 14.819, 14.820000000000002], calculator=SinglePointCalculator(...))


  0%|          | 0/679 [00:00<?, ?it/s]

Atoms(symbols='Ga2', pbc=True, cell=[14.383999999999999, 14.384999999999998, 14.386], calculator=SinglePointCalculator(...))


  0%|          | 0/658 [00:00<?, ?it/s]

Atoms(symbols='Ge2', pbc=True, cell=[14.198, 14.199, 14.200000000000001], calculator=SinglePointCalculator(...))


  0%|          | 0/649 [00:00<?, ?it/s]

Atoms(symbols='As2', pbc=True, cell=[11.655999999999999, 11.656999999999998, 11.658], calculator=SinglePointCalculator(...))


  0%|          | 0/523 [00:00<?, ?it/s]

Atoms(symbols='Se2', pbc=True, cell=[11.284, 11.285, 11.286000000000001], calculator=SinglePointCalculator(...))


  0%|          | 0/504 [00:00<?, ?it/s]

Atoms(symbols='Br2', pbc=True, cell=[11.532000000000002, 11.533000000000001, 11.534000000000002], calculator=SinglePointCalculator(...))


  0%|          | 0/516 [00:00<?, ?it/s]

Atoms(symbols='Kr2', pbc=True, cell=[13.950000000000001, 13.951, 13.952000000000002], calculator=SinglePointCalculator(...))


  0%|          | 0/639 [00:00<?, ?it/s]

Atoms(symbols='Rb2', pbc=True, cell=[19.902, 19.903000000000002, 19.904], calculator=SinglePointCalculator(...))


  0%|          | 0/895 [00:00<?, ?it/s]

Atoms(symbols='Sr2', pbc=True, cell=[17.608, 17.609, 17.61], calculator=SinglePointCalculator(...))


  0%|          | 0/782 [00:00<?, ?it/s]

Atoms(symbols='Y2', pbc=True, cell=[17.05, 17.051000000000002, 17.052], calculator=SinglePointCalculator(...))


  0%|          | 0/757 [00:00<?, ?it/s]

Atoms(symbols='Zr2', pbc=True, cell=[15.624, 15.625, 15.626000000000001], calculator=SinglePointCalculator(...))


  0%|          | 0/693 [00:00<?, ?it/s]

Atoms(symbols='Nb2', pbc=True, cell=[15.872000000000002, 15.873000000000001, 15.874000000000002], calculator=SinglePointCalculator(...))


  0%|          | 0/711 [00:00<?, ?it/s]

Atoms(symbols='Mo2', pbc=True, cell=[15.190000000000001, 15.191, 15.192000000000002], calculator=SinglePointCalculator(...))


  0%|          | 0/682 [00:00<?, ?it/s]

Atoms(symbols='Tc2', pbc=True, cell=[15.128, 15.129, 15.13], calculator=SinglePointCalculator(...))


  0%|          | 0/682 [00:00<?, ?it/s]

Atoms(symbols='Ru2', pbc=True, cell=[15.252, 15.253, 15.254000000000001], calculator=SinglePointCalculator(...))


  0%|          | 0/689 [00:00<?, ?it/s]

Atoms(symbols='Rh2', pbc=True, cell=[15.128, 15.129, 15.13], calculator=SinglePointCalculator(...))


  0%|          | 0/685 [00:00<?, ?it/s]

Atoms(symbols='Pd2', pbc=True, cell=[13.33, 13.331, 13.332], calculator=SinglePointCalculator(...))


  0%|          | 0/597 [00:00<?, ?it/s]

Atoms(symbols='Ag2', pbc=True, cell=[15.686, 15.687, 15.688], calculator=SinglePointCalculator(...))


  0%|          | 0/711 [00:00<?, ?it/s]

Atoms(symbols='Cd2', pbc=True, cell=[15.438000000000002, 15.439000000000002, 15.440000000000003], calculator=SinglePointCalculator(...))


  0%|          | 0/699 [00:00<?, ?it/s]

Atoms(symbols='In2', pbc=True, cell=[15.066, 15.067, 15.068000000000001], calculator=SinglePointCalculator(...))


  0%|          | 0/682 [00:00<?, ?it/s]

Atoms(symbols='Sn2', pbc=True, cell=[15.004, 15.004999999999999, 15.006], calculator=SinglePointCalculator(...))


  0%|          | 0/680 [00:00<?, ?it/s]

Atoms(symbols='Sb2', pbc=True, cell=[15.314000000000002, 15.315000000000001, 15.316000000000003], calculator=SinglePointCalculator(...))


  0%|          | 0/696 [00:00<?, ?it/s]

Atoms(symbols='Te2', pbc=True, cell=[12.338000000000001, 12.339, 12.340000000000002], calculator=SinglePointCalculator(...))


  0%|          | 0/547 [00:00<?, ?it/s]

Atoms(symbols='I2', pbc=True, cell=[12.648000000000001, 12.649000000000001, 12.650000000000002], calculator=SinglePointCalculator(...))


  0%|          | 0/562 [00:00<?, ?it/s]

Atoms(symbols='Xe2', pbc=True, cell=[12.772, 12.773, 12.774000000000001], calculator=SinglePointCalculator(...))


  0%|          | 0/568 [00:00<?, ?it/s]

Atoms(symbols='Cs2', pbc=True, cell=[21.576, 21.577, 21.578], calculator=SinglePointCalculator(...))


  0%|          | 0/978 [00:00<?, ?it/s]

Atoms(symbols='Ba2', pbc=True, cell=[18.785999999999998, 18.787, 18.787999999999997], calculator=SinglePointCalculator(...))


  0%|          | 0/839 [00:00<?, ?it/s]

Atoms(symbols='La2', pbc=True, cell=[18.476, 18.477, 18.477999999999998], calculator=SinglePointCalculator(...))


  0%|          | 0/823 [00:00<?, ?it/s]

Atoms(symbols='Ce2', pbc=True, cell=[17.855999999999998, 17.857, 17.857999999999997], calculator=SinglePointCalculator(...))


  0%|          | 0/792 [00:00<?, ?it/s]

Atoms(symbols='Pr2', pbc=True, cell=[18.104, 18.105, 18.105999999999998], calculator=SinglePointCalculator(...))


  0%|          | 0/805 [00:00<?, ?it/s]

Atoms(symbols='Nd2', pbc=True, cell=[18.290000000000003, 18.291000000000004, 18.292], calculator=SinglePointCalculator(...))


  0%|          | 0/814 [00:00<?, ?it/s]

Atoms(symbols='Pm2', pbc=True, cell=[12.0, 12.001, 12.002], calculator=SinglePointCalculator(...))


  0%|          | 0/500 [00:00<?, ?it/s]

Atoms(symbols='Sm2', pbc=True, cell=[17.98, 17.981, 17.982], calculator=SinglePointCalculator(...))


  0%|          | 0/800 [00:00<?, ?it/s]

Atoms(symbols='Eu2', pbc=True, cell=[17.794, 17.795, 17.796], calculator=SinglePointCalculator(...))


  0%|          | 0/790 [00:00<?, ?it/s]

Atoms(symbols='Gd2', pbc=True, cell=[17.546, 17.547, 17.548], calculator=SinglePointCalculator(...))


  0%|          | 0/779 [00:00<?, ?it/s]

Atoms(symbols='Tb2', pbc=True, cell=[17.298000000000002, 17.299000000000003, 17.3], calculator=SinglePointCalculator(...))


  0%|          | 0/767 [00:00<?, ?it/s]

Atoms(symbols='Dy2', pbc=True, cell=[17.794, 17.795, 17.796], calculator=SinglePointCalculator(...))


  0%|          | 0/793 [00:00<?, ?it/s]

Atoms(symbols='Ho2', pbc=True, cell=[17.422, 17.423000000000002, 17.424], calculator=SinglePointCalculator(...))


  0%|          | 0/775 [00:00<?, ?it/s]

Atoms(symbols='Er2', pbc=True, cell=[17.546, 17.547, 17.548], calculator=SinglePointCalculator(...))


  0%|          | 0/782 [00:00<?, ?it/s]

Atoms(symbols='Tm2', pbc=True, cell=[17.298000000000002, 17.299000000000003, 17.3], calculator=SinglePointCalculator(...))


  0%|          | 0/769 [00:00<?, ?it/s]

Atoms(symbols='Yb2', pbc=True, cell=[17.36, 17.361, 17.362], calculator=SinglePointCalculator(...))


  0%|          | 0/774 [00:00<?, ?it/s]

Atoms(symbols='Lu2', pbc=True, cell=[16.988000000000003, 16.989000000000004, 16.990000000000002], calculator=SinglePointCalculator(...))


  0%|          | 0/755 [00:00<?, ?it/s]

Atoms(symbols='Hf2', pbc=True, cell=[16.306, 16.307000000000002, 16.308], calculator=SinglePointCalculator(...))


  0%|          | 0/727 [00:00<?, ?it/s]

Atoms(symbols='Ta2', pbc=True, cell=[15.686, 15.687, 15.688], calculator=SinglePointCalculator(...))


  0%|          | 0/699 [00:00<?, ?it/s]

Atoms(symbols='W2', pbc=True, cell=[15.934, 15.934999999999999, 15.936], calculator=SinglePointCalculator(...))


  0%|          | 0/715 [00:00<?, ?it/s]

Atoms(symbols='Re2', pbc=True, cell=[15.438000000000002, 15.439000000000002, 15.440000000000003], calculator=SinglePointCalculator(...))


  0%|          | 0/696 [00:00<?, ?it/s]

Atoms(symbols='Os2', pbc=True, cell=[15.376, 15.376999999999999, 15.378], calculator=SinglePointCalculator(...))


  0%|          | 0/696 [00:00<?, ?it/s]

Atoms(symbols='Ir2', pbc=True, cell=[14.942000000000002, 14.943000000000001, 14.944000000000003], calculator=SinglePointCalculator(...))


  0%|          | 0/676 [00:00<?, ?it/s]

Atoms(symbols='Pt2', pbc=True, cell=[14.198, 14.199, 14.200000000000001], calculator=SinglePointCalculator(...))


  0%|          | 0/641 [00:00<?, ?it/s]

Atoms(symbols='Au2', pbc=True, cell=[14.383999999999999, 14.384999999999998, 14.386], calculator=SinglePointCalculator(...))


  0%|          | 0/651 [00:00<?, ?it/s]

Atoms(symbols='Hg2', pbc=True, cell=[15.190000000000001, 15.191, 15.192000000000002], calculator=SinglePointCalculator(...))


  0%|          | 0/693 [00:00<?, ?it/s]

Atoms(symbols='Tl2', pbc=True, cell=[15.314000000000002, 15.315000000000001, 15.316000000000003], calculator=SinglePointCalculator(...))


  0%|          | 0/693 [00:00<?, ?it/s]

Atoms(symbols='Pb2', pbc=True, cell=[16.12, 16.121000000000002, 16.122], calculator=SinglePointCalculator(...))


  0%|          | 0/733 [00:00<?, ?it/s]

Atoms(symbols='Bi2', pbc=True, cell=[15.748000000000001, 15.749, 15.750000000000002], calculator=SinglePointCalculator(...))


  0%|          | 0/713 [00:00<?, ?it/s]

Atoms(symbols='Po2', pbc=True, cell=[12.0, 12.001, 12.002], calculator=SinglePointCalculator(...))


  0%|          | 0/530 [00:00<?, ?it/s]

Atoms(symbols='At2', pbc=True, cell=[12.0, 12.001, 12.002], calculator=SinglePointCalculator(...))


  0%|          | 0/525 [00:00<?, ?it/s]

Atoms(symbols='Rn2', pbc=True, cell=[12.0, 12.001, 12.002], calculator=SinglePointCalculator(...))


  0%|          | 0/525 [00:00<?, ?it/s]

Atoms(symbols='Fr2', pbc=True, cell=[12.0, 12.001, 12.002], calculator=SinglePointCalculator(...))


  0%|          | 0/500 [00:00<?, ?it/s]

Atoms(symbols='Ra2', pbc=True, cell=[12.0, 12.001, 12.002], calculator=SinglePointCalculator(...))


  0%|          | 0/500 [00:00<?, ?it/s]

Atoms(symbols='Ac2', pbc=True, cell=[17.36, 17.361, 17.362], calculator=SinglePointCalculator(...))


  0%|          | 0/768 [00:00<?, ?it/s]

Atoms(symbols='Th2', pbc=True, cell=[18.166, 18.167, 18.168])


  0%|          | 0/808 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Pa2', pbc=True, cell=[17.855999999999998, 17.857, 17.857999999999997])


  0%|          | 0/792 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='U2', pbc=True, cell=[16.802, 16.803, 16.804])


../aten/src/ATen/native/cuda/Indexing.cu:1237: indexSelectSmallIndex: block: [0,0,0], thread: [64,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1237: indexSelectSmallIndex: block: [0,0,0], thread: [65,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1237: indexSelectSmallIndex: block: [0,0,0], thread: [66,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1237: indexSelectSmallIndex: block: [0,0,0], thread: [67,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1237: indexSelectSmallIndex: block: [0,0,0], thread: [68,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1237: indexSelectSmallIndex: block: [0,0,0], thread: [69,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1237: indexSelectSmallIndex: block: [0,0,0], thread: 

  0%|          | 0/742 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Np2', pbc=True, cell=[17.483999999999998, 17.485, 17.485999999999997])


  0%|          | 0/779 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Pu2', pbc=True, cell=[17.422, 17.423000000000002, 17.424])


  0%|          | 0/777 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Am2', pbc=True, cell=[17.546, 17.547, 17.548])


  0%|          | 0/787 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Cm2', pbc=True, cell=[18.91, 18.911, 18.912])


  0%|          | 0/860 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Bk2', pbc=True, cell=[21.08, 21.081, 21.081999999999997])


  0%|          | 0/1044 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Cf2', pbc=True, cell=[18.91, 18.911, 18.912])


  0%|          | 0/935 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Es2', pbc=True, cell=[16.740000000000002, 16.741000000000003, 16.742])


  0%|          | 0/827 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Fm2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/590 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Md2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/590 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='No2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/590 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Lr2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/590 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Rf2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/590 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Db2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/590 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Sg2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/590 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Bh2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/590 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Hs2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/590 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Mt2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/590 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Ds2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/590 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Rg2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/590 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Cn2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/590 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Nh2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/590 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Fl2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/590 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Mc2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/590 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Lv2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/590 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Ts2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/590 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Atoms(symbols='Og2', pbc=True, cell=[12.0, 12.001, 12.002])


  0%|          | 0/590 [00:00<?, ?it/s]

CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.



In [4]:


df = pd.DataFrame(columns=["name", "method", "R", "E", "F", "S^2", "spearman-repulsion", "spearman-attraction"])

for symbol in tqdm(chemical_symbols[1:]):

    da = symbol + symbol

    out_dir = Path(REGISTRY[model_name]["family"]) / da

    traj_fpath = out_dir / f"{model_name}.extxyz"


    if traj_fpath.exists():
        traj = read(traj_fpath, index=":")
    else:
        continue

    Rs, Es, Fs, S2s = [], [], [], []
    for atoms in traj:

        vec = atoms.positions[1] - atoms.positions[0]
        r = np.linalg.norm(vec)
        e = atoms.get_potential_energy()
        f = np.inner(vec/r, atoms.get_forces()[1])
        # s2 = np.mean(np.power(atoms.get_magnetic_moments(), 2))

        Rs.append(r)
        Es.append(e)
        Fs.append(f)
        # S2s.append(s2)

    rs = np.array(Rs)
    es = np.array(Es)
    fs = np.array(Fs)

    indices = np.argsort(rs)[::-1]
    rs = rs[indices]
    es = es[indices]
    fs = fs[indices]

    iminf = np.argmin(fs)
    imine = np.argmin(es)

    data = {
        "name": da,
        "method": model_name,
        "R": Rs,
        "E": Es,
        "F": Fs,
        "S^2": S2s,
        "spearman-descending-force": stats.spearmanr(rs[iminf:], fs[iminf:]).statistic,
        "spearman-ascending-force": stats.spearmanr(rs[:iminf], fs[:iminf]).statistic,
        "spearman-repulsion-energy": stats.spearmanr(rs[imine:], es[imine:]).statistic,
        "spearman-attraction-energy": stats.spearmanr(rs[:imine], es[:imine]).statistic,
    }

    df = pd.concat([df, pd.DataFrame([data])], ignore_index=True)


json_fpath = Path(REGISTRY[model_name]["family"]) / "homonuclear-diatomics.json"

if json_fpath.exists():
    df0 = pd.read_json(json_fpath)
    df = pd.concat([df0, df], ignore_index=True)
    df.drop_duplicates(inplace=True, subset=["name", "method"])

df.to_json(json_fpath, orient="records")

  0%|          | 0/118 [00:00<?, ?it/s]

In [None]:
df