Skip to content

Commit

Permalink
Add hessian (stfc#325)
Browse files Browse the repository at this point in the history
* Add hessian calculation

* Validate if hessian can be calculated
  • Loading branch information
ElliottKasoar authored Oct 11, 2024
1 parent c91d428 commit b52b3c3
Show file tree
Hide file tree
Showing 8 changed files with 177 additions and 41 deletions.
33 changes: 3 additions & 30 deletions janus_core/calculations/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from typing import Any, Optional

from ase import Atoms
from ase.calculators.calculator import Calculator
from ase.calculators.mixing import SumCalculator
import numpy as np

from janus_core.calculations.base import BaseCalculation
Expand All @@ -17,7 +15,7 @@
MaybeSequence,
PathLike,
)
from janus_core.helpers.utils import none_to_dict, output_structs
from janus_core.helpers.utils import check_calculator, none_to_dict, output_structs


class Descriptors(BaseCalculation):
Expand Down Expand Up @@ -166,42 +164,17 @@ def __init__(
raise ValueError("Please attach a calculator to `struct`.")

if isinstance(self.struct, Atoms):
self._check_calculator(self.struct.calc)
check_calculator(self.struct.calc, "get_descriptors")
if isinstance(self.struct, Sequence):
for image in self.struct:
self._check_calculator(image.calc)
check_calculator(image.calc, "get_descriptors")

# Set output file
self.write_kwargs.setdefault("filename", None)
self.write_kwargs["filename"] = self._build_filename(
"descriptors.extxyz", filename=self.write_kwargs["filename"]
).absolute()

@staticmethod
def _check_calculator(calc: Calculator) -> None:
"""
Ensure calculator has ability to calculate descriptors.
Parameters
----------
calc : Calculator
ASE Calculator to calculate descriptors.
"""
# If dispersion added to MLIP calculator, use MLIP calculator for descriptors
if isinstance(calc, SumCalculator):
if (
len(calc.mixer.calcs) == 2
and calc.mixer.calcs[1].name == "TorchDFTD3Calculator"
and hasattr(calc.mixer.calcs[0], "get_descriptors")
):
calc.get_descriptors = calc.mixer.calcs[0].get_descriptors

if not hasattr(calc, "get_descriptors") or not callable(calc.get_descriptors):
raise NotImplementedError(
"The attached calculator does not currently support calculating "
"descriptors"
)

def run(self) -> None:
"""Calculate descriptors for structure(s)."""
if self.logger:
Expand Down
59 changes: 55 additions & 4 deletions janus_core/calculations/single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
PathLike,
Properties,
)
from janus_core.helpers.utils import none_to_dict, output_structs
from janus_core.helpers.utils import check_calculator, none_to_dict, output_structs


class SinglePoint(BaseCalculation):
Expand Down Expand Up @@ -133,7 +133,6 @@ def __init__(
"""
(read_kwargs, write_kwargs) = none_to_dict((read_kwargs, write_kwargs))

self.properties = properties
self.write_results = write_results
self.write_kwargs = write_kwargs
self.log_kwargs = log_kwargs
Expand All @@ -158,6 +157,9 @@ def __init__(
tracker_kwargs=tracker_kwargs,
)

# Properties validated using calculator
self.properties = properties

# Set output file
self.write_kwargs.setdefault("filename", None)
self.write_kwargs["filename"] = self._build_filename(
Expand Down Expand Up @@ -198,9 +200,17 @@ def properties(self, value: MaybeSequence[Properties]) -> None:
f"Property '{prop}' cannot currently be calculated."
)

# If none specified, get all valid properties
# If none specified, get energy, forces and stress
if not value:
value = get_args(Properties)
value = ("energy", "forces", "stress")

# Validate properties
if "hessian" in value:
if isinstance(self.struct, Sequence):
for image in self.struct:
check_calculator(image.calc, "get_hessian")
else:
check_calculator(self.struct.calc, "get_hessian")

self._properties = value

Expand Down Expand Up @@ -246,6 +256,45 @@ def _get_stress(self) -> MaybeList[ndarray]:

return self.struct.get_stress()

def _calc_hessian(self, struct: Atoms) -> ndarray:
"""
Calculate analytical Hessian for a given structure.
Parameters
----------
struct : Atoms
Structure to calculate Hessian for.
Returns
-------
ndarray
Analytical Hessian.
"""
if "arch" in struct.calc.parameters:
arch = struct.calc.parameters["arch"]
label = f"{arch}_"
else:
label = ""

# Calculate hessian
hessian = struct.calc.get_hessian(struct)
struct.info[f"{label}hessian"] = hessian
return hessian

def _get_hessian(self) -> MaybeList[ndarray]:
"""
Calculate hessian using MLIP.
Returns
-------
MaybeList[ndarray]
Hessian of structure(s).
"""
if isinstance(self.struct, Sequence):
return [self._calc_hessian(struct) for struct in self.struct]

return self._calc_hessian(self.struct)

def run(self) -> CalcResults:
"""
Run single point calculations.
Expand All @@ -267,6 +316,8 @@ def run(self) -> CalcResults:
self.results["forces"] = self._get_forces()
if "stress" in self.properties:
self.results["stress"] = self._get_stress()
if "hessian" in self.properties:
self.results["hessian"] = self._get_hessian()

if self.logger:
emissions = self.tracker.stop_task().emissions
Expand Down
4 changes: 2 additions & 2 deletions janus_core/cli/singlepoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def singlepoint(
Device to run model on. Default is "cpu".
model_path : Optional[str]
Path to MLIP model. Default is `None`.
properties : Optional[str]
Physical properties to calculate. Default is "energy".
properties : Optional[list[str]]
Physical properties to calculate. Default is ("energy", "forces", "stress").
out : Optional[Path]
Path to save structure with calculated results. Default is inferred from name
of the structure file.
Expand Down
2 changes: 1 addition & 1 deletion janus_core/helpers/janus_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class CorrelationKwargs(TypedDict, total=True):
]
Devices = Literal["cpu", "cuda", "mps", "xpu"]
Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh"]
Properties = Literal["energy", "stress", "forces"]
Properties = Literal["energy", "stress", "forces", "hessian"]
PhononCalcs = Literal["bands", "dos", "pdos", "thermal"]


Expand Down
32 changes: 32 additions & 0 deletions janus_core/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from typing import Any, Literal, Optional, TextIO, Union, get_args

from ase import Atoms
from ase.calculators.calculator import Calculator
from ase.calculators.mixing import SumCalculator
from ase.io import read, write
from ase.io.formats import filetype
from ase.spacegroup.symmetrize import refine_symmetry
Expand Down Expand Up @@ -739,3 +741,33 @@ def track_progress(sequence: Union[Sequence, Iterable], description: str) -> Ite

with progress:
yield from progress.track(sequence, description=description)


def check_calculator(calc: Calculator, attribute: str) -> None:
"""
Ensure calculator has ability to calculate properties.
If the calculator is a SumCalculator that inlcudes the TorchDFTD3Calculator, this
also sets the relevant function so that the MLIP component of the calculator is
used for properties unrelated to dispersion.
Parameters
----------
calc : Calculator
ASE Calculator to check.
attribute : str
Attribute to check calculator for.
"""
# If dispersion added to MLIP calculator, use only MLIP calculator for calculation
if isinstance(calc, SumCalculator):
if (
len(calc.mixer.calcs) == 2
and calc.mixer.calcs[1].name == "TorchDFTD3Calculator"
and hasattr(calc.mixer.calcs[0], attribute)
):
setattr(calc, attribute, getattr(calc.mixer.calcs[0], attribute))

if not hasattr(calc, attribute) or not callable(getattr(calc, attribute)):
raise NotImplementedError(
f"The attached calculator does not currently support {attribute}"
)
42 changes: 42 additions & 0 deletions tests/test_single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,45 @@ def test_logging(tmp_path):
assert log_file.exists()
assert "emissions" in single_point.struct.info
assert single_point.struct.info["emissions"] > 0


def test_hessian():
"""Test Hessian."""
sp = SinglePoint(
calc_kwargs={"model": MACE_PATH},
struct_path=DATA_PATH / "NaCl.cif",
arch="mace_mp",
properties="hessian",
)
results = sp.run()
assert "hessian" in results
assert results["hessian"].shape == (24, 8, 3)
assert "mace_mp_hessian" in sp.struct.info


def test_hessian_traj():
"""Test calculating Hessian for trajectory."""
sp = SinglePoint(
calc_kwargs={"model": MACE_PATH},
struct_path=DATA_PATH / "benzene-traj.xyz",
arch="mace_mp",
properties="hessian",
)
results = sp.run()
assert "hessian" in results
assert len(results["hessian"]) == 2
assert results["hessian"][0].shape == (36, 12, 3)
assert results["hessian"][1].shape == (36, 12, 3)
assert "mace_mp_hessian" in sp.struct[0].info
assert "mace_mp_hessian" in sp.struct[1].info


@pytest.mark.parametrize("struct", ["NaCl.cif", "benzene-traj.xyz"])
def test_hessian_not_implemented(struct):
"""Test unimplemented Hessian."""
with pytest.raises(NotImplementedError):
SinglePoint(
struct_path=DATA_PATH / struct,
arch="chgnet",
properties="hessian",
)
38 changes: 37 additions & 1 deletion tests/test_singlepoint_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_properties(tmp_path):
log_path = tmp_path / "test.log"
summary_path = tmp_path / "summary.yml"

# Check energy is can be calculated successfully
# Check energy can be calculated successfully
result = runner.invoke(
app,
[
Expand Down Expand Up @@ -355,3 +355,39 @@ def test_write_cif(tmp_path):
assert result.exit_code == 0
atoms = read(results_path)
assert isinstance(atoms, Atoms)


def test_hessian(tmp_path):
"""Test Hessian calculation."""
results_path = tmp_path / "NaCl-results.extxyz"
log_path = tmp_path / "test.log"
summary_path = tmp_path / "summary.yml"

# Check Hessian can be calculated successfully
result = runner.invoke(
app,
[
"singlepoint",
"--struct",
DATA_PATH / "NaCl.cif",
"--properties",
"hessian",
"--properties",
"energy",
"--out",
results_path,
"--calc-kwargs",
"{'dispersion': True}",
"--log",
log_path,
"--summary",
summary_path,
],
)
assert result.exit_code == 0

atoms = read(results_path)
assert "mace_mp_energy" in atoms.info
assert "mace_mp_hessian" in atoms.info
assert "mace_stress" not in atoms.info
assert atoms.info["mace_mp_hessian"].shape == (24, 8, 3)
8 changes: 5 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
"""Test utility functions."""

from pathlib import Path
from typing import get_args

from ase import Atoms
from ase.io import read
import pytest

from janus_core.cli.utils import dict_paths_to_strs, dict_remove_hyphens
from janus_core.helpers.janus_types import Properties
from janus_core.helpers.mlip_calculators import choose_calculator
from janus_core.helpers.utils import none_to_dict, output_structs

Expand Down Expand Up @@ -73,7 +71,11 @@ def test_output_structs(
struct = read(DATA_PATH)
struct.calc = choose_calculator(arch=arch)

results_keys = set(get_args(Properties)) if not properties else set(properties)
if properties:
results_keys = set(properties)
else:
results_keys = {"energy", "forces", "stress"}

label_keys = {f"{arch}_{key}" for key in results_keys}

write_kwargs = {}
Expand Down

0 comments on commit b52b3c3

Please sign in to comment.