Skip to content

Commit

Permalink
Add summary log for MD
Browse files Browse the repository at this point in the history
  • Loading branch information
ElliottKasoar committed Mar 27, 2024
1 parent b579947 commit d96cb8f
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 13 deletions.
72 changes: 62 additions & 10 deletions janus_core/cli.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
"""Set up commandline interface."""

import ast
import datetime
from pathlib import Path
from typing import Annotated, Optional, get_args

import typer
import yaml

from janus_core.geom_opt import optimize
from janus_core.janus_types import Ensembles
Expand Down Expand Up @@ -47,7 +49,7 @@ def __str__(self):
return f"<TyperDict: value={self.value}>"


def parse_dict_class(value: str):
def _parse_dict_class(value: str):
"""
Convert string input into a dictionary.
Expand All @@ -64,7 +66,7 @@ def parse_dict_class(value: str):
return TyperDict(ast.literal_eval(value))


def parse_typer_dicts(typer_dicts: list[TyperDict]) -> list[dict]:
def _parse_typer_dicts(typer_dicts: list[TyperDict]) -> list[dict]:
"""
Convert list of TyperDict objects to list of dictionaries.
Expand All @@ -90,6 +92,22 @@ def parse_typer_dicts(typer_dicts: list[TyperDict]) -> list[dict]:
return typer_dicts


def _iter_path_to_str(dictionary: dict) -> None:
"""
Recursively iterate over dictionary, converting Path values to strings.
Parameters
----------
dictionary : dict
Dictionary to be converted.
"""
for key, value in dictionary.items():
if isinstance(value, dict):
_iter_path_to_str(value)
elif isinstance(value, Path):
dictionary[key] = str(value)


# Shared type aliases
StructPath = Annotated[
Path, typer.Option("--struct", help="Path of structure to simulate.")
Expand All @@ -101,7 +119,7 @@ def parse_typer_dicts(typer_dicts: list[TyperDict]) -> list[dict]:
ReadKwargs = Annotated[
TyperDict,
typer.Option(
parser=parse_dict_class,
parser=_parse_dict_class,
help=(
"""
Keyword arguments to pass to ase.io.read. Must be passed as a dictionary
Expand All @@ -114,7 +132,7 @@ def parse_typer_dicts(typer_dicts: list[TyperDict]) -> list[dict]:
CalcKwargs = Annotated[
TyperDict,
typer.Option(
parser=parse_dict_class,
parser=_parse_dict_class,
help=(
"""
Keyword arguments to pass to selected calculator. Must be passed as a
Expand All @@ -128,7 +146,7 @@ def parse_typer_dicts(typer_dicts: list[TyperDict]) -> list[dict]:
WriteKwargs = Annotated[
TyperDict,
typer.Option(
parser=parse_dict_class,
parser=_parse_dict_class,
help=(
"""
Keyword arguments to pass to ase.io.write when saving results. Must be
Expand Down Expand Up @@ -198,7 +216,7 @@ def singlepoint(
log_file : Optional[Path]
Path to write logs to. Default is "singlepoint.log".
"""
[read_kwargs, calc_kwargs, write_kwargs] = parse_typer_dicts(
[read_kwargs, calc_kwargs, write_kwargs] = _parse_typer_dicts(
[read_kwargs, calc_kwargs, write_kwargs]
)

Expand Down Expand Up @@ -269,7 +287,7 @@ def geomopt( # pylint: disable=too-many-arguments,too-many-locals
opt_kwargs: Annotated[
TyperDict,
typer.Option(
parser=parse_dict_class,
parser=_parse_dict_class,
help=(
"""
Keyword arguments to pass to optimizer. Must be passed as a dictionary
Expand Down Expand Up @@ -322,7 +340,7 @@ def geomopt( # pylint: disable=too-many-arguments,too-many-locals
log_file : Optional[Path]
Path to write logs to. Default is "geomopt.log".
"""
[read_kwargs, calc_kwargs, opt_kwargs, write_kwargs] = parse_typer_dicts(
[read_kwargs, calc_kwargs, opt_kwargs, write_kwargs] = _parse_typer_dicts(
[read_kwargs, calc_kwargs, opt_kwargs, write_kwargs]
)

Expand Down Expand Up @@ -433,7 +451,7 @@ def md( # pylint: disable=too-many-arguments,too-many-locals,invalid-name
minimize_kwargs: Annotated[
TyperDict,
typer.Option(
parser=parse_dict_class,
parser=_parse_dict_class,
help=(
"""
Keyword arguments to pass to optimizer. Must be passed as a dictionary
Expand Down Expand Up @@ -513,6 +531,10 @@ def md( # pylint: disable=too-many-arguments,too-many-locals,invalid-name
Optional[int],
typer.Option(help="Random seed for numpy.random and random functions."),
] = None,
summary: Annotated[
Path,
typer.Option(help="Path to save summary of inputs and start/end time."),
] = "summary.yml",
):
"""
Run molecular dynamics simulation, and save trajectory and statistics.
Expand Down Expand Up @@ -594,8 +616,10 @@ def md( # pylint: disable=too-many-arguments,too-many-locals,invalid-name
seed : Optional[int]
Random seed used by numpy.random and random functions, such as in Langevin.
Default is None.
summary : Path
Path to save summary of inputs and start/end time.
"""
[read_kwargs, calc_kwargs, minimize_kwargs] = parse_typer_dicts(
[read_kwargs, calc_kwargs, minimize_kwargs] = _parse_typer_dicts(
[read_kwargs, calc_kwargs, minimize_kwargs]
)

Expand Down Expand Up @@ -648,6 +672,27 @@ def md( # pylint: disable=too-many-arguments,too-many-locals,invalid-name
"seed": seed,
}

inputs = {"ensemble": ensemble}
for key, value in dyn_kwargs.items():
inputs[key] = value

inputs["struct"] = {
"struct_path": str(struct_path),
"n_atoms": len(s_point.struct),
"formula": s_point.struct.get_chemical_formula(),
}

# Convert all paths to strings
_iter_path_to_str(inputs)

save_info = [
{"command": "janus md"},
{"start_time": datetime.datetime.now().strftime("%d/%m/%Y, %H:%M:%S")},
{"inputs": inputs},
]
with open(summary, "w", encoding="utf8") as outfile:
yaml.dump(save_info, outfile, default_flow_style=False)

if ensemble == "nvt":
for key in ["thermostat_time", "barostat_time", "bulk_modulus", "pressure"]:
del dyn_kwargs[key]
Expand Down Expand Up @@ -679,3 +724,10 @@ def md( # pylint: disable=too-many-arguments,too-many-locals,invalid-name
dyn = NVT_NH(**dyn_kwargs)

dyn.run()

with open(summary, "a", encoding="utf8") as outfile:
yaml.dump(
[{"end_time": datetime.datetime.today().strftime("%d/%m/%Y, %H:%M:%S")}],
outfile,
default_flow_style=False,
)
6 changes: 3 additions & 3 deletions janus_core/md.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Run molecular dynamics simulations."""

import datetime as clock
import datetime
from pathlib import Path
import random
from typing import Any, Optional
Expand Down Expand Up @@ -370,7 +370,7 @@ def get_log_stats(self) -> str:
self.dyn.atoms.info["time_fs"] = time
self.dyn.atoms.info["step"] = step

time_now = clock.datetime.now()
time_now = datetime.datetime.now()
real_time = time_now - self.dyn.atoms.info["real_time"]
self.dyn.atoms.info["real_time"] = time_now

Expand Down Expand Up @@ -451,7 +451,7 @@ def run(self) -> None:
if self.logger:
self.logger.info("Starting molecular dynamics simulation")

self.struct.info["real_time"] = clock.datetime.now()
self.struct.info["real_time"] = datetime.datetime.now()

if self.restart:
try:
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ python = "^3.9"
ase = "^3.22.1"
mace-torch = "^0.3.4"
typer = "^0.9.0"
pyaml = "^23.12.0"

[tool.poetry.group.dev.dependencies]
coverage = {extras = ["toml"], version = "^7.4.1"}
Expand Down
54 changes: 54 additions & 0 deletions tests/test_md_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ase.io import read
import pytest
from typer.testing import CliRunner
import yaml

from janus_core.cli import app

Expand Down Expand Up @@ -36,6 +37,7 @@ def test_md(ensemble, tmp_path):
"""Test all MD simulations are able to run."""
file_prefix = tmp_path / f"{ensemble}-T300"
traj_path = tmp_path / f"{ensemble}-T300-traj.xyz"
summary_path = tmp_path / "summary.yml"

result = runner.invoke(
app,
Expand All @@ -53,6 +55,8 @@ def test_md(ensemble, tmp_path):
2,
"--traj-every",
1,
"--summary",
summary_path,
],
)

Expand All @@ -67,6 +71,7 @@ def test_md_log(tmp_path, caplog):
"""Test log correctly written for MD."""
file_prefix = tmp_path / "nvt-T300"
stats_path = tmp_path / "nvt-T300-stats.dat"
summary_path = tmp_path / "summary.yml"

with caplog.at_level("INFO", logger="janus_core.md"):
result = runner.invoke(
Expand All @@ -85,6 +90,8 @@ def test_md_log(tmp_path, caplog):
20,
"--stats-every",
1,
"--summary",
summary_path,
],
)
assert result.exit_code == 0
Expand Down Expand Up @@ -114,6 +121,7 @@ def test_seed(tmp_path):
"""Test seed enables reproducable results for NVT."""
file_prefix = tmp_path / "nvt-T300"
stats_path = tmp_path / "nvt-T300-stats.dat"
summary_path = tmp_path / "summary.yml"

result_1 = runner.invoke(
app,
Expand All @@ -133,6 +141,8 @@ def test_seed(tmp_path):
20,
"--seed",
42,
"--summary",
summary_path,
],
)
assert result_1.exit_code == 0
Expand Down Expand Up @@ -164,6 +174,8 @@ def test_seed(tmp_path):
20,
"--seed",
42,
"--summary",
summary_path,
],
)
assert result_2.exit_code == 0
Expand All @@ -178,3 +190,45 @@ def test_seed(tmp_path):
for i, (stats_1, stats_2) in enumerate(zip(final_stats_1, final_stats_2)):
if i != 1:
assert stats_1 == stats_2


def test_summary(tmp_path):
"""Test summary file can be read correctly."""
file_prefix = tmp_path / "nvt-T300"
summary_path = tmp_path / "summary.yml"

result = runner.invoke(
app,
[
"md",
"--ensemble",
"nve",
"--struct",
DATA_PATH / "NaCl.cif",
"--temp",
300,
"--file-prefix",
file_prefix,
"--steps",
2,
"--traj-every",
1,
"--summary",
summary_path,
],
)

assert result.exit_code == 0

# Read summary
with open(summary_path, encoding="utf8") as file:
summary = yaml.safe_load(file)

assert "command" in summary[0]
assert "start_time" in summary[1]
assert "inputs" in summary[2]
assert "end_time" in summary[3]

assert "ensemble" in summary[2]["inputs"]
assert "struct" in summary[2]["inputs"]
assert "n_atoms" in summary[2]["inputs"]["struct"]

0 comments on commit d96cb8f

Please sign in to comment.