Skip to content

Commit

Permalink
Clean calculated results
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed Mar 6, 2024
1 parent 508e659 commit 53fd1b5
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
23 changes: 22 additions & 1 deletion janus_core/single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Literal, Optional, Union

from ase.io import read, write
from numpy import ndarray
from numpy import isfinite, ndarray

from janus_core.mlip_calculators import architectures, choose_calculator, devices

Expand Down Expand Up @@ -170,6 +170,25 @@ def _get_stress(self) -> Union[ndarray, list[ndarray]]:

return self.struct.get_stress()

def _clean_results(self):
"""Remove results with NaN or inf values from calc.results dictionary."""

if isinstance(self.struct, list):
for image in self.struct:
rm_keys = []
for prop in image.calc.results:
if not isfinite(image.calc.results[prop]).all():
rm_keys.append(prop)
for prop in rm_keys:
image.calc.results.pop(prop)
else:
rm_keys = []
for prop in self.struct.calc.results:
if not isfinite(self.struct.calc.results[prop]).all():
rm_keys.append(prop)
for prop in rm_keys:
self.struct.calc.results.pop(prop)

def run_single_point(
self,
properties: Optional[Union[str, list[str]]] = None,
Expand Down Expand Up @@ -209,6 +228,8 @@ def run_single_point(
if "stress" in properties or len(properties) == 0:
results["stress"] = self._get_stress()

self._clean_results()

if write_kwargs:
write(images=self.struct, **write_kwargs)

Expand Down
21 changes: 21 additions & 0 deletions tests/test_single_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path

from ase.io import read
from numpy import isfinite
import pytest

from janus_core.single_point import SinglePoint
Expand Down Expand Up @@ -104,3 +105,23 @@ def test_single_point_write_missing():
)
with pytest.raises(ValueError):
single_point.run_single_point(write_kwargs={"file": "file.xyz"})


def test_single_point_write_nan(tmp_path):
"""Test non-finite singlepoint results removed."""
data_path = DATA_PATH / "H2O.cif"
results_path = tmp_path / "H2O.xyz"
single_point = SinglePoint(
structure=data_path,
architecture="mace",
calc_kwargs={"model_paths": MODEL_PATH},
)

assert isfinite(single_point.run_single_point("energy")["energy"]).all()
assert not isfinite(single_point.run_single_point("stress")["stress"]).all()

single_point.run_single_point(write_kwargs={"filename": results_path})
atoms = read(results_path)
assert atoms.get_potential_energy() is not None
assert "forces" in atoms.calc.results
assert "stress" not in atoms.calc.results

0 comments on commit 53fd1b5

Please sign in to comment.