diff --git a/janus_core/single_point.py b/janus_core/single_point.py index 231ab59f..556e2f29 100644 --- a/janus_core/single_point.py +++ b/janus_core/single_point.py @@ -4,7 +4,7 @@ from typing import Any, Literal, Optional, Union from ase.io import read, write -from numpy import ndarray +from numpy import isfinite, ndarray from janus_core.mlip_calculators import architectures, choose_calculator, devices @@ -170,6 +170,25 @@ def _get_stress(self) -> Union[ndarray, list[ndarray]]: return self.struct.get_stress() + def _clean_results(self): + """Remove results with NaN or inf values from calc.results dictionary.""" + + if isinstance(self.struct, list): + for image in self.struct: + rm_keys = [] + for prop in image.calc.results: + if not isfinite(image.calc.results[prop]).all(): + rm_keys.append(prop) + for prop in rm_keys: + image.calc.results.pop(prop) + else: + rm_keys = [] + for prop in self.struct.calc.results: + if not isfinite(self.struct.calc.results[prop]).all(): + rm_keys.append(prop) + for prop in rm_keys: + self.struct.calc.results.pop(prop) + def run_single_point( self, properties: Optional[Union[str, list[str]]] = None, @@ -209,6 +228,8 @@ def run_single_point( if "stress" in properties or len(properties) == 0: results["stress"] = self._get_stress() + self._clean_results() + if write_kwargs: write(images=self.struct, **write_kwargs) diff --git a/tests/test_single_point.py b/tests/test_single_point.py index dcaf1333..62dd489d 100644 --- a/tests/test_single_point.py +++ b/tests/test_single_point.py @@ -3,6 +3,7 @@ from pathlib import Path from ase.io import read +from numpy import isfinite import pytest from janus_core.single_point import SinglePoint @@ -104,3 +105,23 @@ def test_single_point_write_missing(): ) with pytest.raises(ValueError): single_point.run_single_point(write_kwargs={"file": "file.xyz"}) + + +def test_single_point_write_nan(tmp_path): + """Test non-finite singlepoint results removed.""" + data_path = DATA_PATH / "H2O.cif" + results_path = tmp_path / "H2O.xyz" + single_point = SinglePoint( + structure=data_path, + architecture="mace", + calc_kwargs={"model_paths": MODEL_PATH}, + ) + + assert isfinite(single_point.run_single_point("energy")["energy"]).all() + assert not isfinite(single_point.run_single_point("stress")["stress"]).all() + + single_point.run_single_point(write_kwargs={"filename": results_path}) + atoms = read(results_path) + assert atoms.get_potential_energy() is not None + assert "forces" in atoms.calc.results + assert "stress" not in atoms.calc.results