Skip to content

Commit

Permalink
Add saving singlepoint to CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed Mar 5, 2024
1 parent 6f122a9 commit 508e659
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
22 changes: 19 additions & 3 deletions janus_core/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ def singlepoint(
list[str],
typer.Option(
"--property",
help="Properties to calculate. If not specified, 'energy', 'forces', \
and 'stress' will be returned.",
help=(
"Properties to calculate. If not specified, 'energy', 'forces' "
"and 'stress' will be returned."
),
),
] = None,
read_kwargs: Annotated[
Expand All @@ -93,6 +95,17 @@ def singlepoint(
metavar="DICT",
),
] = None,
write_kwargs: Annotated[
TyperDict,
typer.Option(
parser=parse_dict_class,
help=(
"Keyword arguments to pass to ase.io.write to save "
"results [default: {}]"
),
metavar="DICT",
),
] = None,
):
"""
Perform single point calculations.
Expand All @@ -112,9 +125,12 @@ def singlepoint(
Keyword arguments to pass to ase.io.read. Default is {}.
calc_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to the selected calculator. Default is {}.
write_kwargs : Optional[dict[str, Any]]
Keyword arguments to pass to ase.io.write to save results. Default is {}.
"""
read_kwargs = read_kwargs.value if read_kwargs else {}
calc_kwargs = calc_kwargs.value if calc_kwargs else {}
write_kwargs = write_kwargs.value if write_kwargs else {}

if not isinstance(read_kwargs, dict):
raise ValueError("read_kwargs must be a dictionary")
Expand All @@ -128,7 +144,7 @@ def singlepoint(
read_kwargs=read_kwargs,
calc_kwargs=calc_kwargs,
)
print(s_point.run_single_point(properties=properties))
print(s_point.run_single_point(properties=properties, write_kwargs=write_kwargs))


@app.command()
Expand Down
20 changes: 20 additions & 0 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from pathlib import Path

from ase.io import read
from typer.testing import CliRunner

from janus_core.cli import app
Expand Down Expand Up @@ -90,3 +91,22 @@ def test_singlepoint_calc_kwargs():
)
assert result.exit_code == 0
assert "Using float32 for MACECalculator" in result.stdout


def test_singlepoint_write_kwargs(tmp_path):
"""Test setting write_kwargs for singlepoint calculation."""
result = runner.invoke(
app,
[
"singlepoint",
"--structure",
DATA_PATH / "NaCl.cif",
"--write-kwargs",
f"{{'filename': '{str(tmp_path / 'NaCl.xyz')}'}}",
"--property",
"energy",
],
)
assert result.exit_code == 0
atoms = read(tmp_path / "NaCl.xyz")
assert "forces" in atoms.arrays

0 comments on commit 508e659

Please sign in to comment.