Skip to content

Commit

Permalink
Fix model path kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed Mar 13, 2024
1 parent 035bfa5 commit 4f5bc80
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 31 deletions.
15 changes: 13 additions & 2 deletions janus_core/mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,27 +43,38 @@ def choose_calculator(
# pylint: disable=import-outside-toplevel, too-many-branches, import-error
# Optional imports handled via `architecture`. We could catch these,
# but the error message is clear if imports are missing.
if "model" in kwargs and "model_paths" in kwargs:
raise ValueError("Please specify either `model` or `model_paths`")

if architecture == "mace":
from mace import __version__
from mace.calculators import MACECalculator

# `model_paths` is keyword for path to model, so take from kwargs if specified
# Otherwise, take `model` if specified, then default to `None`, which will
# raise a ValueError
kwargs.setdefault("model_paths", kwargs.pop("model", None))
kwargs.setdefault("default_dtype", "float64")
calculator = MACECalculator(device=device, **kwargs)

elif architecture == "mace_mp":
from mace import __version__
from mace.calculators import mace_mp

# `model` is keyword for path to model, so take from kwargs if specified
# Otherwise, take `model_paths` if specified, then default to "small"
kwargs.setdefault("model", kwargs.pop("model_paths", "small"))
kwargs.setdefault("default_dtype", "float64")
kwargs["model"] = kwargs.pop("model_paths", "small")
calculator = mace_mp(**kwargs)

elif architecture == "mace_off":
from mace import __version__
from mace.calculators import mace_off

# `model` is keyword for path to model, so take from kwargs if specified
# Otherwise, take `model_paths` if specified, then default to "small"
kwargs.setdefault("model", kwargs.pop("model_paths", "small"))
kwargs.setdefault("default_dtype", "float64")
kwargs["model"] = kwargs.pop("model_paths", "small")
calculator = mace_off(**kwargs)

elif architecture == "m3gnet":
Expand Down
43 changes: 28 additions & 15 deletions tests/test_mlip_calculators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,15 @@

from janus_core.mlip_calculators import choose_calculator

MODEL_PATH = Path(__file__).parent / "models" / "mace_mp_small.model"

test_data_mace = [
(
"mace",
"cpu",
{"model_paths": Path(__file__).parent / "models" / "mace_mp_small.model"},
),
("mace", "cpu", {"model": MODEL_PATH}),
("mace", "cpu", {"model_paths": MODEL_PATH}),
("mace_off", "cpu", {}),
("mace_mp", "cpu", {}),
(
"mace_mp",
"cpu",
{"model_paths": Path(__file__).parent / "models" / "mace_mp_small.model"},
),
(
"mace_off",
"cpu",
{"model_paths": "small"},
),
("mace_mp", "cpu", {"model": MODEL_PATH}),
("mace_off", "cpu", {"model": "small"}),
]

test_data_extras = [("m3gnet", "cpu"), ("chgnet", "")]
Expand Down Expand Up @@ -51,3 +42,25 @@ def test_invalid_arch():
"""Test error raised for invalid architecture."""
with pytest.raises(ValueError):
choose_calculator(architecture="invalid")


def test_model_model_paths():
"""Test error raised if both model and model_paths are specified."""
with pytest.raises(ValueError):
choose_calculator(
architecture="mace",
model=MODEL_PATH,
model_paths=MODEL_PATH,
)
with pytest.raises(ValueError):
choose_calculator(
architecture="mace_mp",
model=MODEL_PATH,
model_paths=MODEL_PATH,
)
with pytest.raises(ValueError):
choose_calculator(
architecture="mace_off",
model=MODEL_PATH,
model_paths=MODEL_PATH,
)
28 changes: 14 additions & 14 deletions tests/test_single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_potential_energy(
struct_path, expected, properties, prop_key, calc_kwargs, idx
):
"""Test single point energy using MACE calculators."""
calc_kwargs["model_paths"] = MODEL_PATH
calc_kwargs["model"] = MODEL_PATH
single_point = SinglePoint(
struct_path=struct_path, architecture="mace", calc_kwargs=calc_kwargs
)
Expand All @@ -56,7 +56,7 @@ def test_single_point_none():
single_point = SinglePoint(
struct_path=DATA_PATH / "NaCl.cif",
architecture="mace",
calc_kwargs={"model_paths": MODEL_PATH},
calc_kwargs={"model": MODEL_PATH},
)

results = single_point.run_single_point()
Expand All @@ -69,7 +69,7 @@ def test_single_point_clean():
single_point = SinglePoint(
struct_path=DATA_PATH / "H2O.cif",
architecture="mace",
calc_kwargs={"model_paths": MODEL_PATH},
calc_kwargs={"model": MODEL_PATH},
)

results = single_point.run_single_point()
Expand All @@ -84,7 +84,7 @@ def test_single_point_traj():
struct_path=DATA_PATH / "benzene-traj.xyz",
architecture="mace",
read_kwargs={"index": ":"},
calc_kwargs={"model_paths": MODEL_PATH},
calc_kwargs={"model": MODEL_PATH},
)

assert len(single_point.struct) == 2
Expand All @@ -102,7 +102,7 @@ def test_single_point_write():
single_point = SinglePoint(
struct_path=data_path,
architecture="mace",
calc_kwargs={"model_paths": MODEL_PATH},
calc_kwargs={"model": MODEL_PATH},
)
assert "forces" not in single_point.struct.arrays

Expand All @@ -122,7 +122,7 @@ def test_single_point_write_kwargs(tmp_path):
single_point = SinglePoint(
struct_path=data_path,
architecture="mace",
calc_kwargs={"model_paths": MODEL_PATH},
calc_kwargs={"model": MODEL_PATH},
)
assert "forces" not in single_point.struct.arrays

Expand All @@ -141,7 +141,7 @@ def test_single_point_write_nan(tmp_path):
single_point = SinglePoint(
struct_path=data_path,
architecture="mace",
calc_kwargs={"model_paths": MODEL_PATH},
calc_kwargs={"model": MODEL_PATH},
)

assert isfinite(single_point.run_single_point("energy")["energy"]).all()
Expand All @@ -162,7 +162,7 @@ def test_invalid_prop():
single_point = SinglePoint(
struct_path=DATA_PATH / "H2O.cif",
architecture="mace",
calc_kwargs={"model_paths": MODEL_PATH},
calc_kwargs={"model": MODEL_PATH},
)
with pytest.raises(NotImplementedError):
single_point.run_single_point("invalid")
Expand All @@ -175,7 +175,7 @@ def test_atoms():
struct=struct,
struct_name="NaCl",
architecture="mace",
calc_kwargs={"model_paths": MODEL_PATH},
calc_kwargs={"model": MODEL_PATH},
)
assert single_point.struct_name == "NaCl"
assert single_point.run_single_point("energy")["energy"] < 0
Expand All @@ -187,7 +187,7 @@ def test_default_atoms_name():
single_point = SinglePoint(
struct=struct,
architecture="mace",
calc_kwargs={"model_paths": MODEL_PATH},
calc_kwargs={"model": MODEL_PATH},
)
assert single_point.struct_name == "Cl4Na4"

Expand All @@ -198,7 +198,7 @@ def test_default_path_name():
single_point = SinglePoint(
struct_path=struct_path,
architecture="mace",
calc_kwargs={"model_paths": MODEL_PATH},
calc_kwargs={"model": MODEL_PATH},
)
assert single_point.struct_name == "NaCl"

Expand All @@ -210,7 +210,7 @@ def test_path_specify_name():
struct_path=struct_path,
struct_name="example_name",
architecture="mace",
calc_kwargs={"model_paths": MODEL_PATH},
calc_kwargs={"model": MODEL_PATH},
)
assert single_point.struct_name == "example_name"

Expand All @@ -224,7 +224,7 @@ def test_atoms_and_path():
struct=struct,
struct_path=struct_path,
architecture="mace",
calc_kwargs={"model_paths": MODEL_PATH},
calc_kwargs={"model": MODEL_PATH},
)


Expand All @@ -233,5 +233,5 @@ def test_no_atoms_or_path():
with pytest.raises(ValueError):
SinglePoint(
architecture="mace",
calc_kwargs={"model_paths": MODEL_PATH},
calc_kwargs={"model": MODEL_PATH},
)

0 comments on commit 4f5bc80

Please sign in to comment.