"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import hashlib
import json
import os
from pathlib import Path

import ase
import gradio as gr
import huggingface_hub as hf_hub
from ase.calculators.calculator import Calculator
from ase.db.core import now
from ase.db.row import AtomsRow
from ase.io.jsonio import decode, encode


def hash_save_file(atoms: ase.Atoms, task_name, path: Path | str):
    """Write a copy of the atoms to `path`, naming the file by the MD5 hash of
    its JSON serialization."""
    atoms = atoms.copy()
    atoms.info["task_name"] = task_name
    atoms.write(
        Path(path)
        / f"{hashlib.md5(atoms_to_json(atoms).encode('utf-8')).hexdigest()}.traj"
    )


class HFEndpointCalculator(Calculator):
    """A simple ASE calculator that runs predictions via a Hugging Face
    Inference Endpoint."""

    implemented_properties = ["energy", "free_energy", "stress", "forces"]

    def __init__(
        self,
        atoms,
        endpoint_url,
        oauth_token,
        task_name,
        example=False,
        *args,
        **kwargs,
    ):
        # If we have an example structure, we don't need to check for authentication
        # Otherwise, we need to check if the user is authenticated and has gated access to the UMA models
        if not example:
            try:
                hf_hub.HfApi().auth_check(
                    repo_id="facebook/UMA", token=oauth_token.token
                )
                try:
                    hash_save_file(atoms, task_name, "/data/custom_inputs/")
                except FileNotFoundError:
                    pass
            except (hf_hub.errors.HfHubHTTPError, AttributeError):
                raise gr.Error(
                    "You need to log in to HF and have gated model access to UMA before running your own simulations!"
                )

        self.client = hf_hub.InferenceClient(
            model=endpoint_url, token=os.environ["HF_TOKEN"]
        )
        self.atoms = atoms
        self.task_name = task_name

        super().__init__(*args, **kwargs)

    def calculate(self, atoms, properties, system_changes):
        """Send the serialized atoms to the inference endpoint and store the
        returned results on the calculator and the atoms object."""
        Calculator.calculate(self, atoms, properties, system_changes)

        task_name = self.task_name.lower()

        # Run inference via a POST request to the endpoint
        try:
            response = self.client.post(
                json={
                    "inputs": atoms_to_json(atoms, data=atoms.info),
                    "properties": properties,
                    "system_changes": system_changes,
                    "task_name": task_name,
                }
            )
        except hf_hub.errors.BadRequestError:
            hash_save_file(atoms, task_name, "/data/custom_inputs/errors/")
            raise gr.Error(
                "Backend failure during your calculation; if the problem persists, please file an issue in the main FAIR chemistry repo (https://github.com/facebookresearch/fairchem)."
            )

        # Load the response and store the results in the calc and atoms object
        response_dict = decode(json.loads(response))
        self.results = response_dict["results"]
        atoms.info = response_dict["info"]


def atoms_to_json(atoms, data=None):
    """Serialize an Atoms object to a JSON string using the same row format as
    ase.db.jsondb."""

    mtime = now()

    row = AtomsRow(atoms)
    row.ctime = mtime

    dct = {}
    for key in row.__dict__:
        if key[0] == "_" or key in row._keys or key == "id":
            continue
        dct[key] = row[key]

    dct["mtime"] = mtime

    if data:
        dct["data"] = data

    constraints = row.get("constraints")
    if constraints:
        dct["constraints"] = constraints

    return encode(dct)
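

# Usage sketch (illustrative only): how the calculator might be attached to an
# ASE Atoms object. The endpoint URL and task name below are hypothetical
# placeholders, and HF_TOKEN is assumed to be set in the environment; in the
# real app the OAuth token comes from the Gradio login flow.
if __name__ == "__main__":
    from ase.build import bulk

    example_atoms = bulk("Cu")
    example_atoms.calc = HFEndpointCalculator(
        example_atoms,
        endpoint_url="https://example.endpoints.huggingface.cloud",  # placeholder
        oauth_token=None,  # only safe with example=True, which skips the auth check
        task_name="omat",  # placeholder task name
        example=True,
    )
    print(example_atoms.get_potential_energy())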