cyrusyc commited on
Commit
97d8be3
·
1 Parent(s): da3f047

improve deps and imports

Browse files
mlip_arena/models/__init__.py CHANGED
@@ -42,7 +42,7 @@ for model, metadata in REGISTRY.items():
42
  f"{__package__}.{metadata['module']}.{metadata['family']}"
43
  )
44
  MLIPMap[model] = getattr(module, metadata["class"])
45
- except (ModuleNotFoundError, AttributeError, ValueError, ImportError) as e:
46
  logger.warning(e)
47
  continue
48
 
 
42
  f"{__package__}.{metadata['module']}.{metadata['family']}"
43
  )
44
  MLIPMap[model] = getattr(module, metadata["class"])
45
+ except (ModuleNotFoundError, AttributeError, ValueError, ImportError, Exception) as e:
46
  logger.warning(e)
47
  continue
48
 
mlip_arena/tasks/__init__.py CHANGED
@@ -6,6 +6,13 @@ from huggingface_hub import HfApi, HfFileSystem, hf_hub_download
6
  # from mlip_arena.models import MLIP
7
  # from mlip_arena.models import REGISTRY as MODEL_REGISTRY
8
 
 
 
 
 
 
 
 
9
  try:
10
  from .elasticity import run as ELASTICITY
11
  from .eos import run as EOS
@@ -16,8 +23,9 @@ try:
16
  from .phonon import run as PHONON
17
 
18
  __all__ = ["OPT", "EOS", "MD", "NEB", "NEB_FROM_ENDPOINTS", "ELASTICITY", "PHONON"]
19
- except ImportError:
20
- pass
 
21
 
22
  with open(Path(__file__).parent / "registry.yaml", encoding="utf-8") as f:
23
  REGISTRY = yaml.safe_load(f)
 
6
  # from mlip_arena.models import MLIP
7
  # from mlip_arena.models import REGISTRY as MODEL_REGISTRY
8
 
9
+ try:
10
+ from prefect.logging import get_run_logger
11
+
12
+ logger = get_run_logger()
13
+ except (ImportError, RuntimeError):
14
+ from loguru import logger
15
+
16
  try:
17
  from .elasticity import run as ELASTICITY
18
  from .eos import run as EOS
 
23
  from .phonon import run as PHONON
24
 
25
  __all__ = ["OPT", "EOS", "MD", "NEB", "NEB_FROM_ENDPOINTS", "ELASTICITY", "PHONON"]
26
+ except (ImportError, TypeError) as e:
27
+ logger.warning(e)
28
+
29
 
30
  with open(Path(__file__).parent / "registry.yaml", encoding="utf-8") as f:
31
  REGISTRY = yaml.safe_load(f)
mlip_arena/tasks/utils.py CHANGED
@@ -8,7 +8,6 @@ import torch
8
  from ase import units
9
  from ase.calculators.calculator import BaseCalculator
10
  from ase.calculators.mixing import SumCalculator
11
- from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
12
 
13
  from mlip_arena.models import MLIPEnum
14
 
@@ -102,6 +101,13 @@ def get_calculator(
102
  dispersion_kwargs.update({"device": device})
103
 
104
  if dispersion:
 
 
 
 
 
 
 
105
  disp_calc = TorchDFTD3Calculator(
106
  **dispersion_kwargs,
107
  )
@@ -112,5 +118,5 @@ def get_calculator(
112
  if dispersion_kwargs:
113
  logger.info(pformat(dispersion_kwargs))
114
 
115
- assert isinstance(calc, BaseCalculator)
116
  return calc
 
8
  from ase import units
9
  from ase.calculators.calculator import BaseCalculator
10
  from ase.calculators.mixing import SumCalculator
 
11
 
12
  from mlip_arena.models import MLIPEnum
13
 
 
101
  dispersion_kwargs.update({"device": device})
102
 
103
  if dispersion:
104
+ try:
105
+ from torch_dftd.torch_dftd3_calculator import TorchDFTD3Calculator
106
+ except ImportError as e:
107
+ raise ImportError(
108
+ "torch_dftd is required for dispersion but is not installed."
109
+ ) from e
110
+
111
  disp_calc = TorchDFTD3Calculator(
112
  **dispersion_kwargs,
113
  )
 
118
  if dispersion_kwargs:
119
  logger.info(pformat(dispersion_kwargs))
120
 
121
+ assert isinstance(calc, BaseCalculator)
122
  return calc
pyproject.toml CHANGED
@@ -10,7 +10,7 @@ authors=[
10
  ]
11
  description="Fair and transparent benchmark of machine learning interatomic potentials (MLIPs), beyond error-based regression metrics"
12
  readme=".github/README.md"
13
- requires-python=">=3.10"
14
  keywords=[
15
  "pytorch",
16
  "machine-learning-interatomic-potentials",
@@ -27,6 +27,7 @@ classifiers=[
27
  "Programming Language :: Python :: 3 :: Only",
28
  ]
29
  dependencies=[
 
30
  "ase",
31
  "pymatgen",
32
  "torch",
 
10
  ]
11
  description="Fair and transparent benchmark of machine learning interatomic potentials (MLIPs), beyond error-based regression metrics"
12
  readme=".github/README.md"
13
+ requires-python=">=3.9"
14
  keywords=[
15
  "pytorch",
16
  "machine-learning-interatomic-potentials",
 
27
  "Programming Language :: Python :: 3 :: Only",
28
  ]
29
  dependencies=[
30
+ "loguru",
31
  "ase",
32
  "pymatgen",
33
  "torch",