In [1]:
import os
from mp_api.client import MPRester
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
from prefect import task, flow
from prefect.task_runners import ThreadPoolTaskRunner
from prefect_dask import DaskTaskRunner
from pymatgen.core.structure import Structure
from dotenv import load_dotenv
from ase import Atoms
from ase.io import write, read
from pathlib import Path
import pandas as pd
from prefect.futures import wait

from mlip_arena.tasks.eos.run import fit as EOS
from mlip_arena.models.utils import REGISTRY, MLIPEnum

# Load variables from a local .env file (if present) before looking up the key.
load_dotenv()

# Materials Project API key; stays None when the variable is not set.
MP_API_KEY = os.getenv("MP_API_KEY")

In [2]:

# Fetch summary documents for stable materials, filtered on num_elements=(1, 2),
# requesting only the id, structure and pretty formula fields.
with MPRester(MP_API_KEY) as mpr:
    # Record which database release the results below came from.
    print("MP Database version:", mpr.get_database_version())

    summary_docs = mpr.materials.summary.search(
        fields=["material_id", "structure", "formula_pretty"],
        is_stable=True,
        num_elements=(1, 2),
    )


MP Database version: 2023.11.1


Retrieving SummaryDoc documents:   0%|          | 0/5135 [00:00<?, ?it/s]

In [3]:

# Convert every retrieved pymatgen Structure into an ASE Atoms object.
atoms_list = []

for doc in summary_docs:
    structure = doc.structure
    # Fail early if a document does not carry a real Structure.
    assert isinstance(structure, Structure)
    atoms_list.append(structure.to_ase_atoms())


In [4]:
# Save every converted structure into a single file, all.extxyz, in the
# current working directory.
write("all.extxyz", atoms_list)

In [2]:
# Reload the structures from all.extxyz; index=':' asks ASE for every frame
# in the file rather than a single one.
atoms_list = read("all.extxyz", index=':')

In [3]:
# Resources requested for each SLURM allocation.
nodes_per_alloc = 1
gpus_per_alloc = 4
ntasks = 1

# Settings handed to dask-jobqueue when it generates the worker job script.
cluster_kwargs = dict(
    cores=1,
    memory="64 GB",
    shebang="#!/bin/bash",
    account="matgen",
    walltime="00:30:00",
    job_mem="0",
    # Shell lines executed in the job before the Dask worker starts.
    job_script_prologue=[
        "source ~/.bashrc",
        "module load python",
        "source activate /pscratch/sd/c/cyrusyc/.conda/mlip-arena",
    ],
    # Drop the default directives for these flags; "-J" is re-added below.
    job_directives_skip=["-n", "--cpus-per-task", "-J"],
    job_extra_directives=[
        f"-N {nodes_per_alloc}",
        f"-G {gpus_per_alloc}",
        "-q debug",
        "-C gpu",
        "-J eos",
    ],
)
cluster = SLURMCluster(**cluster_kwargs)

# Show the generated submission script so the directives can be checked.
print(cluster.job_script())

# Pin the cluster to exactly two worker jobs, then connect a client to it.
cluster.adapt(minimum_jobs=2, maximum_jobs=2)
client = Client(cluster)


#!/bin/bash

#SBATCH -A matgen
#SBATCH --mem=0
#SBATCH -t 00:30:00
#SBATCH -N 1
#SBATCH -G 4
#SBATCH -q debug
#SBATCH -C gpu
#SBATCH -J eos
source ~/.bashrc
module load python
source activate /pscratch/sd/c/cyrusyc/.conda/mlip-arena
/pscratch/sd/c/cyrusyc/.conda/mlip-arena/bin/python -m distributed.cli.dask_worker tcp://128.55.64.49:36289 --name dummy-name --nthreads 1 --memory-limit 59.60GiB --nanny --death-timeout 60



In [4]:
from prefect.concurrency.sync import concurrency
from prefect.runtime import flow_run, task_run

def postprocess(output, model: str, formula: str):
    """Insert or update one EOS result in the per-model parquet table.

    The table is stored at ``<family>/<model>.parquet``, where ``family`` is
    ``REGISTRY[model]["family"]``. Rows are keyed by ``("formula", "method")``;
    when a row with the same key already exists, the new one replaces it.

    Args:
        output: Mapping that provides ``output["eos"]["volumes"]``,
            ``output["eos"]["energies"]`` and ``output["K"]``.
        model: Key into ``REGISTRY``; also stored in the ``method`` column.
        formula: Value stored in the ``formula`` column.

    Returns:
        None. The result is persisted to disk.
    """
    import uuid  # local import: only needed for the unique temp-file name

    new_row = pd.DataFrame(
        [
            {
                "formula": formula,
                "method": model,
                "volumes": output["eos"]["volumes"],
                "energies": output["eos"]["energies"],
                "K": output["K"],
            }
        ]
    )

    fpath = Path(REGISTRY[model]["family"]) / f"{model}.parquet"
    fpath.parent.mkdir(parents=True, exist_ok=True)

    if fpath.exists():
        df = pd.concat([pd.read_parquet(fpath), new_row], ignore_index=True)
    else:
        df = new_row

    df = df.drop_duplicates(subset=["formula", "method"], keep="last")

    # Write to a uniquely named sibling file and rename it over the target.
    # This function runs on several workers at once, and overwriting the
    # table in place lets a concurrent reader see a half-written file.
    # NOTE(review): two simultaneous read-modify-write cycles can still lose
    # one row; a lock or per-task output files would be needed to rule that out.
    tmp_path = fpath.with_name(f".{fpath.name}.{uuid.uuid4().hex}.tmp")
    try:
        df.to_parquet(tmp_path)
        tmp_path.replace(fpath)
    finally:
        # No-op after a successful rename; removes the partial file on failure.
        tmp_path.unlink(missing_ok=True)



task_runner = DaskTaskRunner(address=client.scheduler.address)
EOS = EOS.with_options(
    # task_runner=task_runner, 
    log_prints=True,
    timeout_seconds=120, 
    # result_storage=None
)

from prefect import get_client

async with get_client() as client:
    limit_id = await client.create_concurrency_limit(
        tag="bottleneck", 
        concurrency_limit=2
    )

def generate_task_run_name():
    """Return the run name "<task name>: <chemical formula>".

    Both parts come from the ``task_run`` runtime context: the task name and
    the chemical formula of the ``atoms`` parameter of the current run.
    """
    name = task_run.task_name
    formula = task_run.parameters["atoms"].get_chemical_formula()
    return f"{name}: {formula}"

@task(task_run_name=generate_task_run_name, tags=["bottleneck"], timeout_seconds=150)
def fit_one(atoms: Atoms, model: str):
    """Run the EOS task for one structure with one model.

    When EOS returns a dict, the result is persisted through ``postprocess``
    and tagged with the model name under ``"method"``. Anything else is
    returned untouched.

    Args:
        atoms: Structure to fit.
        model: Calculator name passed to EOS and recorded with the result.

    Returns:
        Whatever EOS returned, with ``"method"`` added when it is a dict.
    """
    eos_settings = dict(
        atoms=atoms,
        calculator_name=model,
        calculator_kwargs={},
        device=None,
        optimizer="QuasiNewton",
        optimizer_kwargs=None,
        filter="FrechetCell",
        filter_kwargs=None,
        criterion={"fmax": 0.1},
        max_abs_strain=0.1,
        npoints=7,
    )
    result = EOS(**eos_settings)

    # Only dict results are stored and labelled.
    if not isinstance(result, dict):
        return result

    postprocess(output=result, model=model, formula=atoms.get_chemical_formula())
    result["method"] = model
    return result
    
#https://docs-3.prefect.io/3.0/develop/task-runners#use-multiple-task-runners
# @flow(task_runner=ThreadPoolTaskRunner(max_workers=50), log_prints=True)
@flow(task_runner=task_runner, log_prints=True)
def fit_all(atoms_list: list[Atoms]):
    """Submit one ``fit_one`` task per (structure, model) pair and gather results.

    Tasks are submitted structure by structure, iterating over every member
    of ``MLIPEnum`` for each one. Results are returned in submission order.

    Args:
        atoms_list: Structures to fit.

    Returns:
        List with the result of every submitted task.
    """
    # Submit everything first so the task runner can schedule the work.
    futures = [
        fit_one.submit(atoms, model.name)
        for atoms in atoms_list
        for model in MLIPEnum
    ]

    # Then block on each future in turn.
    return [future.result() for future in futures]


# @task(task_run_name=generate_task_run_name, result_storage=None)
# def fit_one(atoms: Atoms):
    
#     outputs = []
#     for model in MLIPEnum:
#         try:
#             eos = EOS(
#                 atoms=atoms,
#                 calculator_name=model.name,
#                 calculator_kwargs={},
#                 device=None,
#                 optimizer="QuasiNewton",
#                 optimizer_kwargs=None,
#                 filter="FrechetCell",
#                 filter_kwargs=None,
#                 criterion=dict(
#                     fmax=0.1,
#                 ),
#                 max_abs_strain=0.1,
#                 npoints=7,
#             )
#             if isinstance(eos, dict):
#                 postprocess(output=eos, model=model.name, formula=atoms.get_chemical_formula())
#                 eos["method"] = model.name
#                 outputs.append(eos)
#         except:
#             continue
    
#     return outputs

# # https://orion-docs.prefect.io/latest/concepts/task-runners/#using-multiple-task-runners
# @flow(task_runner=DaskTaskRunner(address=client.scheduler.address), log_prints=True, result_storage=None)
# def fit_all(atoms_list: list[Atoms]):
    
#     futures = []
#     for atoms in atoms_list:
#         future = fit_one.submit(atoms)
#         futures.append(future)
            
#     wait(futures)
    
#     return [f.result(raise_on_failure=False) for f in futures]

In [None]:
# Launch the flow: one fit_one task is submitted per (structure, model) pair
# and the call blocks until every result has been collected.
fit_all(atoms_list)

```
Note that, because the DaskTaskRunner uses multiprocessing, calls to flows in scripts must be guarded with if __name__ == "__main__": or you will encounter warnings and errors.
```

In [9]:
# import os
# import tempfile
# import shutil
# from contextlib import contextmanager

# @contextmanager
# def twd():
    
#     pwd = os.getcwd()
#     temp_dir = tempfile.mkdtemp()
    
#     try:
#         os.chdir(temp_dir)
#         yield
#     finally:
#         os.chdir(pwd)
#         shutil.rmtree(temp_dir)

# with twd():

# fit_all(atoms_list)

In [10]:
import pandas as pd

# Load the stored results table for the MACE-MP(M) model.
# NOTE(review): the path is hard-coded; presumably it matches the
# "<family>/<model>.parquet" location written by postprocess — confirm.
df = pd.read_parquet('mace-mp/MACE-MP(M).parquet')

In [11]:
# Show the loaded table as this cell's output.
df

Unnamed: 0,formula,method,volumes,energies,K
1,Ac2O3,MACE-MP(M),"[82.36010147441682, 85.41047560309894, 88.4608...","[-39.47541427612305, -39.65580749511719, -39.7...",95.755459
2,Ac6In2,MACE-MP(M),"[278.3036976131417, 288.61124196918433, 298.91...","[-31.21324348449707, -31.40914535522461, -31.5...",33.370214
3,Ac6Tl2,MACE-MP(M),"[278.30267000598286, 288.6101763025008, 298.91...","[-29.572534561157227, -29.833026885986328, -30...",29.065081
4,Ac3Sn,MACE-MP(M),"[135.293532345587, 140.30440391394214, 145.315...","[-17.135194778442383, -17.228239059448242, -17...",30.622045
5,AcAg,MACE-MP(M),"[55.376437498321394, 57.4274166649259, 59.4783...","[-7.274301528930664, -7.346108913421631, -7.39...",40.212164
6,Ac4,MACE-MP(M),"[166.09086069175856, 172.2423740507126, 178.39...","[-16.326059341430664, -16.406923294067383, -16...",25.409891
7,Ac16S24,MACE-MP(M),"[1006.5670668063424, 1043.84732853991, 1081.12...","[-249.4179229736328, -250.7970733642578, -251....",61.734158
