diff --git a/docs/source/apidoc/janus_core.rst b/docs/source/apidoc/janus_core.rst index 1a65f026..c5d5041a 100644 --- a/docs/source/apidoc/janus_core.rst +++ b/docs/source/apidoc/janus_core.rst @@ -278,6 +278,7 @@ janus\_core.processing.observables module :private-members: :undoc-members: :show-inheritance: + :inherited-members: janus\_core.processing.post\_process module ------------------------------------------- diff --git a/docs/source/conf.py b/docs/source/conf.py index 1ca0dcbb..18589ceb 100755 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -203,6 +203,7 @@ ("py:class", "Architectures"), ("py:class", "Devices"), ("py:class", "MaybeSequence"), + ("py:class", "SliceLike"), ("py:class", "PathLike"), ("py:class", "Atoms"), ("py:class", "Calculator"), diff --git a/docs/source/developer_guide/tutorial.rst b/docs/source/developer_guide/tutorial.rst index e2154ccf..60001157 100644 --- a/docs/source/developer_guide/tutorial.rst +++ b/docs/source/developer_guide/tutorial.rst @@ -187,20 +187,87 @@ Alternatively, using ``tox``:: Adding a new Observable ======================= -Additional built-in observable quantities may be added for use by the ``janus_core.processing.correlator.Correlation`` class. These should conform to the ``__call__`` signature of ``janus_core.helpers.janus_types.Observable``. For a user this can be accomplished by writing a function, or class also implementing a commensurate ``__call__``. +A :class:`janus_core.processing.observables.Observable` abstracts obtaining a quantity derived from ``Atoms``. They may be used as kernels for input into analysis such as a correlation. -Built-in observables are collected within the ``janus_core.processing.observables`` module. For example the ``janus_core.processing.observables.Stress`` observable allows a user to quickly setup a given correlation of stress tensor components (with and without the ideal gas contribution). An observable for the ``xy`` component is obtained without the ideal gas contribution as: +Additional built-in observable quantities may be added for use by the :class:`janus_core.processing.correlator.Correlation` class. These should extend :class:`janus_core.processing.observables.Observable` and are implemented within the :py:mod:`janus_core.processing.observables` module. + +The abstract method ``__call__`` should be implemented to obtain the values of the observed quantity from an ``Atoms`` object. When used as part of a :class:`janus_core.processing.correlator.Correlation`, each value will be correlated and the results averaged. + +As an example of building a new ``Observable`` consider the :class:`janus_core.processing.observables.Stress` built-in. The following steps may be taken: + +1. Defining the observable. +--------------------------- + +The stress tensor may be computed on an atoms object using ``Atoms.get_stress``. A user may wish to obtain a particular component, or perhaps only compute the stress on some subset of ``Atoms``. For example during a :class:`janus_core.calculations.md.MolecularDynamics` run a user may wish to correlate only the off-diagonal components (shear stress), computed across all atoms. + +2. Writing the ``__call__`` method. +----------------------------------- + +In the call method we can use the base :class:`janus_core.processing.observables.Observable`'s optional atom selector ``atoms_slice`` to first define the subset of atoms to compute the stress for: .. code-block:: python - Stress("xy", False) + def __call__(self, atoms: Atoms) -> list[float]: + sliced_atoms = atoms[self.atoms_slice] + # must be re-attached after slicing for get_stress + sliced_atoms.calc = atoms.calc -A new built-in observables can be implemented by a class with the method: +Next the stresses may be obtained from: .. code-block:: python - def __call__(self, atoms: Atoms, *args, **kwargs) -> float + stresses = ( + sliced_atoms.get_stress( + include_ideal_gas=self.include_ideal_gas, voigt=True + ) + / units.GPa + ) + +Finally, to facilitate handling components in a symbolic way, :class:`janus_core.processing.observables.ComponentMixin` exists to parse ``str`` symbolic components to ``int`` indices by defining a suitable mapping. For the stress tensor (and the format of ``Atoms.get_stress``) a suitable mapping is defined in :class:`janus_core.processing.observables.Stress`'s ``__init__`` method: -The ``__call__`` should contain all the logic for obtaining some ``float`` value from an ``Atoms`` object, alongside optional positional arguments and kwargs. The args and kwargs are set by a user when specifying correlations for a ``janus_core.calculations.md.MolecularDynamics`` run. See also ``janus_core.helpers.janus_types.CorrelationKwargs``. These are set at the instantiation of the ``janus_core.calculations.md.MolecularDynamics`` object and are not modified. These could be used e.g. to specify an observable calculated only from one atom's data. +.. code-block:: python + + ComponentMixin.__init__( + self, + components={ + "xx": 0, + "yy": 1, + "zz": 2, + "yz": 3, + "zy": 3, + "xz": 4, + "zx": 4, + "xy": 5, + "yx": 5, + }, + ) + +This then concludes the ``__call__`` method for :class:`janus_core.processing.observables.Stress` by using :class:`janus_core.processing.observables.ComponentMixin`'s +pre-calculated indices: + +.. code-block:: python + + return stesses[self._indices] + +The combination of the above means a user may obtain, say, the ``xy`` and ``zy`` stress tensor components over odd-indexed atoms by calling the following observable on an ``Atoms``: + +.. code-block:: python + + s = Stress(components=["xy", "zy"], atoms_slice=(0, None, 2)) + + +Since usually total system stresses are required we can define two built-ins to handle the shear and hydrostatic stresses like so: + +.. code-block:: python + + StressHydrostatic = Stress(components=["xx", "yy", "zz"]) + StressShear = Stress(components=["xy", "yz", "zx"]) + +Where by default :class:`janus_core.processing.observables.Observable`'s ``atoms_slice`` is ``slice(0, None, 1)``, which expands to all atoms in an ``Atoms``. + +For comparison the :class:`janus_core.processing.observables.Velocity` built-in's ``__call__`` not only returns atom velocity for the requested components, but also returns them for every tracked atom i.e: + +.. code-block:: python -``janus_core.processing.observables.Stress`` includes a constructor to take a symbolic component, e.g. ``"xx"`` or ``"yz"``, and determine the index required from ``ase.Atoms.get_stress`` on instantiation for ease of use. + def __call__(self, atoms: Atoms) -> list[float]: + return atoms.get_velocities()[self.atoms_slice, :][:, self._indices].flatten() diff --git a/janus_core/helpers/janus_types.py b/janus_core/helpers/janus_types.py index 708c7845..27c6c556 100644 --- a/janus_core/helpers/janus_types.py +++ b/janus_core/helpers/janus_types.py @@ -6,22 +6,16 @@ from enum import Enum import logging from pathlib import Path, PurePath -from typing import ( - IO, - Literal, - Optional, - Protocol, - TypedDict, - TypeVar, - Union, - runtime_checkable, -) +from typing import IO, TYPE_CHECKING, Literal, Optional, TypedDict, TypeVar, Union from ase import Atoms from ase.eos import EquationOfState import numpy as np from numpy.typing import NDArray +if TYPE_CHECKING: + from janus_core.processing.observables import Observable + # General T = TypeVar("T") @@ -86,32 +80,13 @@ class PostProcessKwargs(TypedDict, total=False): vaf_output_file: PathLike | None -@runtime_checkable -class Observable(Protocol): - """Signature for correlation observable getter.""" - - def __call__(self, atoms: Atoms, *args, **kwargs) -> float: - """ - Call the getter. - - Parameters - ---------- - atoms : Atoms - Atoms object to extract values from. - *args : tuple - Additional positional arguments passed to getter. - **kwargs : dict - Additional kwargs passed getter. - """ - - class CorrelationKwargs(TypedDict, total=True): """Arguments for on-the-fly correlations .""" #: observable a in , with optional args and kwargs - a: Observable | tuple[Observable, tuple, dict] + a: Observable #: observable b in , with optional args and kwargs - b: Observable | tuple[Observable, tuple, dict] + b: Observable #: name used for correlation in output name: str #: blocks used in multi-tau algorithm diff --git a/janus_core/helpers/utils.py b/janus_core/helpers/utils.py index 9ac0e58f..32ed6043 100644 --- a/janus_core/helpers/utils.py +++ b/janus_core/helpers/utils.py @@ -18,7 +18,12 @@ ) from rich.style import Style -from janus_core.helpers.janus_types import MaybeSequence, PathLike +from janus_core.helpers.janus_types import ( + MaybeSequence, + PathLike, + SliceLike, + StartStopStep, +) class FileNameMixin(ABC): # noqa: B024 (abstract-base-class-without-abstract-method) @@ -432,3 +437,83 @@ def check_files_exist(config: dict, req_file_keys: Sequence[PathLike]) -> None: # Only check if file key is in the configuration file if not Path(config[file_key]).exists(): raise FileNotFoundError(f"{config[file_key]} does not exist") + + +def validate_slicelike(maybe_slicelike: SliceLike) -> None: + """ + Raise an exception if slc is not a valid SliceLike. + + Parameters + ---------- + maybe_slicelike : SliceLike + Candidate to test. + + Raises + ------ + ValueError + If maybe_slicelike is not SliceLike. + """ + if isinstance(maybe_slicelike, (slice, range, int)): + return + if isinstance(maybe_slicelike, tuple) and len(maybe_slicelike) == 3: + start, stop, step = maybe_slicelike + if ( + (start is None or isinstance(start, int)) + and (stop is None or isinstance(stop, int)) + and isinstance(step, int) + ): + return + + raise ValueError(f"{maybe_slicelike} is not a valid SliceLike") + + +def slicelike_to_startstopstep(index: SliceLike) -> StartStopStep: + """ + Standarize `SliceLike`s into tuple of `start`, `stop`, `step`. + + Parameters + ---------- + index : SliceLike + `SliceLike` to standardize. + + Returns + ------- + StartStopStep + Standardized `SliceLike` as `start`, `stop`, `step` triplet. + """ + validate_slicelike(index) + if isinstance(index, int): + if index == -1: + return (index, None, 1) + return (index, index + 1, 1) + + if isinstance(index, (slice, range)): + return (index.start, index.stop, index.step) + + return index + + +def selector_len(slc: SliceLike | list, selectable_length: int) -> int: + """ + Calculate the length of a selector applied to an indexable of a given length. + + Parameters + ---------- + slc : Union[SliceLike, list] + The applied SliceLike or list for selection. + selectable_length : int + The length of the selectable object. + + Returns + ------- + int + Length of the result of applying slc. + """ + if isinstance(slc, int): + return 1 + if isinstance(slc, list): + return len(slc) + start, stop, step = slicelike_to_startstopstep(slc) + if stop is None: + stop = selectable_length + return len(range(start, stop, step)) diff --git a/janus_core/processing/correlator.py b/janus_core/processing/correlator.py index 39752a82..10f18c66 100644 --- a/janus_core/processing/correlator.py +++ b/janus_core/processing/correlator.py @@ -7,7 +7,7 @@ from ase import Atoms import numpy as np -from janus_core.helpers.janus_types import Observable +from janus_core.processing.observables import Observable class Correlator: @@ -140,19 +140,39 @@ def _shifts_valid(self, block: int, p_i: int, p_j: int) -> bool: """ return self._shift_not_null[block, p_i] and self._shift_not_null[block, p_j] - def get(self) -> tuple[Iterable[float], Iterable[float]]: + def get_lags(self) -> Iterable[float]: """ - Obtain the correlation and lag times. + Obtain the correlation lag times. Returns ------- - correlation : Iterable[float] + Iterable[float] + The correlation lag times. + """ + lags = np.zeros(self._points * self._blocks) + + lag = 0 + for i in range(self._points): + if self._count_correlated[0, i] > 0: + lags[lag] = i + lag += 1 + for k in range(1, self._max_block_used): + for i in range(self._min_dist, self._points): + if self._count_correlated[k, i] > 0: + lags[lag] = float(i) * float(self._averaging) ** k + lag += 1 + return lags[0:lag] + + def get_value(self) -> Iterable[float]: + """ + Obtain the correlation value. + + Returns + ------- + Iterable[float] The correlation values . - lags : Iterable[float]] - The correlation lag times t'. """ correlation = np.zeros(self._points * self._blocks) - lags = np.zeros(self._points * self._blocks) lag = 0 for i in range(self._points): @@ -160,7 +180,6 @@ def get(self) -> tuple[Iterable[float], Iterable[float]]: correlation[lag] = ( self._correlation[0, i] / self._count_correlated[0, i] ) - lags[lag] = i lag += 1 for k in range(1, self._max_block_used): for i in range(self._min_dist, self._points): @@ -168,9 +187,8 @@ def get(self) -> tuple[Iterable[float], Iterable[float]]: correlation[lag] = ( self._correlation[k, i] / self._count_correlated[k, i] ) - lags[lag] = float(i) * float(self._averaging) ** k lag += 1 - return (correlation[0:lag], lags[0:lag]) + return correlation[0:lag] class Correlation: @@ -179,10 +197,10 @@ class Correlation: Parameters ---------- - a : tuple[Observable, dict] - Getter for a and kwargs. - b : tuple[Observable, dict] - Getter for b and kwargs. + a : Observable + Observable for a. + b : Observable + Observable for b. name : str Name of correlation. blocks : int @@ -197,8 +215,9 @@ class Correlation: def __init__( self, - a: Observable | tuple[Observable, tuple, dict], - b: Observable | tuple[Observable, tuple, dict], + *, + a: Observable, + b: Observable, name: str, blocks: int, points: int, @@ -210,10 +229,10 @@ def __init__( Parameters ---------- - a : tuple[Observable, tuple, dict] - Getter for a and kwargs. - b : tuple[Observable, tuple, dict] - Getter for b and kwargs. + a : Observable + Observable for a. + b : Observable + Observable for b. name : str Name of correlation. blocks : int @@ -226,19 +245,13 @@ def __init__( Frequency to update the correlation, md steps. """ self.name = name - if isinstance(a, tuple): - self._get_a, self._a_args, self._a_kwargs = a - else: - self._get_a = a - self._a_args, self._a_kwargs = (), {} + self.blocks = blocks + self.points = points + self.averaging = averaging + self._get_a = a + self._get_b = b - if isinstance(b, tuple): - self._get_b, self._b_args, self._b_kwargs = b - else: - self._get_b = b - self._b_args, self._b_kwargs = (), {} - - self._correlator = Correlator(blocks=blocks, points=points, averaging=averaging) + self._correlators = None self._update_frequency = update_frequency @property @@ -262,14 +275,20 @@ def update(self, atoms: Atoms) -> None: atoms : Atoms Atoms object to observe values from. """ - self._correlator.update( - self._get_a(atoms, *self._a_args, **self._a_kwargs), - self._get_b(atoms, *self._b_args, **self._b_kwargs), - ) + value_pairs = zip(self._get_a(atoms), self._get_b(atoms)) + if self._correlators is None: + self._correlators = [ + Correlator( + blocks=self.blocks, points=self.points, averaging=self.averaging + ) + for _ in range(len(self._get_a(atoms))) + ] + for corr, values in zip(self._correlators, value_pairs): + corr.update(*values) def get(self) -> tuple[Iterable[float], Iterable[float]]: """ - Get the correlation value and lags. + Get the correlation value and lags, averaging over atoms if applicable. Returns ------- @@ -278,7 +297,10 @@ def get(self) -> tuple[Iterable[float], Iterable[float]]: lags : Iterable[float]] The correlation lag times t'. """ - return self._correlator.get() + if self._correlators: + lags = self._correlators[0].get_lags() + return np.mean([cor.get_value() for cor in self._correlators], axis=0), lags + return [], [] def __str__(self) -> str: """ diff --git a/janus_core/processing/observables.py b/janus_core/processing/observables.py index 32eee246..52fe0b91 100644 --- a/janus_core/processing/observables.py +++ b/janus_core/processing/observables.py @@ -2,73 +2,268 @@ from __future__ import annotations +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + from ase import Atoms, units +if TYPE_CHECKING: + from janus_core.helpers.janus_types import SliceLike + +from janus_core.helpers.utils import slicelike_to_startstopstep + + +# pylint: disable=too-few-public-methods +class Observable(ABC): + """ + Observable data that may be correlated. + + Parameters + ---------- + atoms_slice : list[int] | SliceLike | None = None + A slice of atoms to observe. + """ + + def __init__(self, atoms_slice: list[int] | SliceLike | None = None): + """ + Initialise an observable with a given dimensionality. + + Parameters + ---------- + atoms_slice : list[int] | SliceLike | None + A slice of atoms to observe. By default all atoms are included. + """ + if not atoms_slice: + self.atoms_slice = slice(0, None, 1) + return + + if isinstance(atoms_slice, list): + self.atoms_slice = atoms_slice + else: + self.atoms_slice = slice(*slicelike_to_startstopstep(atoms_slice)) + + @abstractmethod + def __call__(self, atoms: Atoms) -> list[float]: + """ + Signature for returning observed value from atoms. + + Parameters + ---------- + atoms : Atoms + Atoms object to extract values from. + + Returns + ------- + list[float] + The observed value, with dimensions atoms by self.dimension. + """ + + +class ComponentMixin: + """ + Mixin to handle Observables with components. -class Stress: + Parameters + ---------- + components : dict[str, int] + Symbolic components mapped to indices. + """ + + def __init__(self, components: dict[str, int]): + """ + Initialise the mixin with components. + + Parameters + ---------- + components : dict[str, int] + Symbolic components mapped to indices. + """ + self._allowed_components = components + + @property + def _indices(self) -> list[int]: + """ + Get indices associated with self.components. + + Returns + ------- + list[int] + The indices for each self.components. + """ + return [self._allowed_components[c] for c in self.components] + + @property + def components(self) -> list[str]: + """ + Get the symbolic components of the observable. + + Returns + ------- + list[str] + The observables components. + """ + return self._components + + @components.setter + def components(self, components: list[str]): + """ + Check if components are valid, if so set them. + + Parameters + ---------- + components : str + The component symbols to check. + + Raises + ------ + ValueError + If any component is invalid. + """ + if any(components - self._allowed_components.keys()): + raise ValueError( + f"'{components-self._allowed_components.keys()}'" + f" invalid, must be '{', '.join(self._allowed_components)}'" + ) + + self._components = components + + +# pylint: disable=too-few-public-methods +class Stress(Observable, ComponentMixin): """ Observable for stress components. Parameters ---------- - component : str - Symbol for tensor components, xx, yy, etc. + components : list[str] + Symbols for correlated tensor components, xx, yy, etc. + atoms_slice : list[int] | SliceLike | None = None + List or slice of atoms to observe velocities from. include_ideal_gas : bool Calculate with the ideal gas contribution. """ - def __init__(self, component: str, *, include_ideal_gas: bool = True) -> None: + def __init__( + self, + *, + components: list[str], + atoms_slice: list[int] | SliceLike | None = None, + include_ideal_gas: bool = True, + ): """ - Initialise the observables from a symbolic str component. + Initialise the observable from a symbolic str component. Parameters ---------- - component : str - Symbol for tensor components, xx, yy, etc. + components : list[str] + Symbols for tensor components, xx, yy, etc. + atoms_slice : list[int] | SliceLike | None = None + List or slice of atoms to observe velocities from. include_ideal_gas : bool Calculate with the ideal gas contribution. """ - components = { - "xx": 0, - "yy": 1, - "zz": 2, - "yz": 3, - "zy": 3, - "xz": 4, - "zx": 4, - "xy": 5, - "yx": 5, - } - if component not in components: - raise ValueError( - f"'{component}' invalid, must be '{', '.join(list(components.keys()))}'" - ) + ComponentMixin.__init__( + self, + components={ + "xx": 0, + "yy": 1, + "zz": 2, + "yz": 3, + "zy": 3, + "xz": 4, + "zx": 4, + "xy": 5, + "yx": 5, + }, + ) + self.components = components - self.component = component - self._index = components[self.component] + Observable.__init__(self, atoms_slice) self.include_ideal_gas = include_ideal_gas - def __call__(self, atoms: Atoms, *args, **kwargs) -> float: + def __call__(self, atoms: Atoms) -> list[float]: """ - Get the stress component. + Get the stress components. Parameters ---------- atoms : Atoms Atoms object to extract values from. - *args : tuple - Additional positional arguments passed to getter. - **kwargs : dict - Additional kwargs passed getter. Returns ------- - float - The stress component in GPa units. + list[float] + The stress components in GPa units. + + Raises + ------ + ValueError + If atoms is not an Atoms object. """ - return ( - atoms.get_stress(include_ideal_gas=self.include_ideal_gas, voigt=True)[ - self._index - ] + if not isinstance(atoms, Atoms): + raise ValueError("atoms should be an Atoms object") + sliced_atoms = atoms[self.atoms_slice] + # must be re-attached after slicing for get_stress + sliced_atoms.calc = atoms.calc + stresses = ( + sliced_atoms.get_stress( + include_ideal_gas=self.include_ideal_gas, voigt=True + ) / units.GPa ) + return stresses[self._indices] + + +StressHydrostatic = Stress(components=["xx", "yy", "zz"]) +StressShear = Stress(components=["xy", "yz", "zx"]) + + +# pylint: disable=too-few-public-methods +class Velocity(Observable, ComponentMixin): + """ + Observable for per atom velocity components. + + Parameters + ---------- + components : list[str] + Symbols for velocity components, x, y, z. + atoms_slice : list[int] | SliceLike | None = None + List or slice of atoms to observe velocities from. + """ + + def __init__( + self, + *, + components: list[str], + atoms_slice: list[int] | SliceLike | None = None, + ): + """ + Initialise the observable from a symbolic str component and atom index. + + Parameters + ---------- + components : list[str] + Symbols for tensor components, x, y, and z. + atoms_slice : Union[list[int], SliceLike] + List or slice of atoms to observe velocities from. + """ + ComponentMixin.__init__(self, components={"x": 0, "y": 1, "z": 2}) + self.components = components + + Observable.__init__(self, atoms_slice) + + def __call__(self, atoms: Atoms) -> list[float]: + """ + Get the velocity components for correlated atoms. + + Parameters + ---------- + atoms : Atoms + Atoms object to extract values from. + + Returns + ------- + list[float] + The velocity values. + """ + return atoms.get_velocities()[self.atoms_slice, :][:, self._indices].flatten() diff --git a/janus_core/processing/post_process.py b/janus_core/processing/post_process.py index 044b8d02..2c9a66ff 100644 --- a/janus_core/processing/post_process.py +++ b/janus_core/processing/post_process.py @@ -15,33 +15,8 @@ MaybeSequence, PathLike, SliceLike, - StartStopStep, ) - - -def _process_index(index: SliceLike) -> StartStopStep: - """ - Standarize `SliceLike`s into tuple of `start`, `stop`, `step`. - - Parameters - ---------- - index : SliceLike - `SliceLike` to standardize. - - Returns - ------- - StartStopStep - Standardized `SliceLike` as `start`, `stop`, `step` triplet. - """ - if isinstance(index, int): - if index == -1: - return (index, None, 1) - return (index, index + 1, 1) - - if isinstance(index, (slice, range)): - return (index.start, index.stop, index.step) - - return index +from janus_core.helpers.utils import slicelike_to_startstopstep def compute_rdf( @@ -94,7 +69,7 @@ def compute_rdf( If `by_elements` is true returns a `dict` of RDF by element pairs. Otherwise returns RDF of total system filtered by elements. """ - index = _process_index(index) + index = slicelike_to_startstopstep(index) if not isinstance(data, Sequence): data = [data] @@ -261,7 +236,7 @@ def compute_vaf( ) # Extract requested data - index = _process_index(index) + index = slicelike_to_startstopstep(index) data = data[slice(*index)] if use_velocities: diff --git a/tests/test_correlator.py b/tests/test_correlator.py index 91b390e5..ded1aba0 100644 --- a/tests/test_correlator.py +++ b/tests/test_correlator.py @@ -5,18 +5,18 @@ from collections.abc import Iterable from pathlib import Path -from ase import Atoms from ase.io import read from ase.units import GPa import numpy as np from pytest import approx from typer.testing import CliRunner -from yaml import Loader, load +from yaml import Loader, load, safe_load from janus_core.calculations.md import NVE from janus_core.calculations.single_point import SinglePoint +from janus_core.processing import post_process from janus_core.processing.correlator import Correlator -from janus_core.processing.observables import Stress +from janus_core.processing.observables import Stress, Velocity DATA_PATH = Path(__file__).parent / "data" MODEL_PATH = Path(__file__).parent / "models" / "mace_mp_small.model" @@ -43,7 +43,7 @@ def correlate( def test_setup(): """Test initial values.""" cor = Correlator(blocks=1, points=100, averaging=2) - correlation, lags = cor.get() + correlation, lags = cor.get_value(), cor.get_lags() assert len(correlation) == len(lags) assert len(correlation) == 0 @@ -55,7 +55,7 @@ def test_correlation(): signal = np.exp(-np.linspace(0.0, 1.0, points)) for val in signal: cor.update(val, val) - correlation, lags = cor.get() + correlation, lags = cor.get_value(), cor.get_lags() direct = correlate(signal, signal, fft=False) fft = correlate(signal, signal, fft=True) @@ -66,8 +66,8 @@ def test_correlation(): assert fft == approx(correlation, rel=1e-10) -def test_md_correlations(tmp_path): - """Test correlations as part of MD cycle.""" +def test_vaf(tmp_path): + """Test the correlator against post-process.""" file_prefix = tmp_path / "Cl4Na4-nve-T300.0" traj_path = tmp_path / "Cl4Na4-nve-T300.0-traj.extxyz" cor_path = tmp_path / "Cl4Na4-nve-T300.0-cor.dat" @@ -78,14 +78,8 @@ def test_md_correlations(tmp_path): calc_kwargs={"model": MODEL_PATH}, ) - def user_observable_a(atoms: Atoms, kappa, **kwargs) -> float: - """User specified getter for correlation.""" - return ( - kwargs["gamma"] - * kappa - * atoms.get_stress(include_ideal_gas=True, voigt=True)[-1] - / GPa - ) + na = list(range(0, len(single_point.struct), 2)) + cl = list(range(1, len(single_point.struct), 2)) nve = NVE( struct=single_point.struct, @@ -97,17 +91,73 @@ def user_observable_a(atoms: Atoms, kappa, **kwargs) -> float: file_prefix=file_prefix, correlation_kwargs=[ { - "a": (user_observable_a, (2,), {"gamma": 2}), - "b": Stress("xy"), - "name": "user_correlation", + "a": Velocity(components=["x", "y", "z"], atoms_slice=(0, None, 2)), + "b": Velocity( + components=["x", "y", "z"], + atoms_slice=range(0, len(single_point.struct), 2), + ), + "name": "vaf_Na", "blocks": 1, "points": 11, "averaging": 1, "update_frequency": 1, }, { - "a": Stress("xy"), - "b": Stress("xy"), + "a": Velocity(components=["x", "y", "z"], atoms_slice=cl), + "b": Velocity( + components=["x", "y", "z"], atoms_slice=slice(1, None, 2) + ), + "name": "vaf_Cl", + "blocks": 1, + "points": 11, + "averaging": 1, + "update_frequency": 1, + }, + ], + write_kwargs={"invalidate_calc": False}, + ) + + nve.run() + + assert cor_path.exists() + assert traj_path.exists() + + traj = read(traj_path, index=":") + vaf_post = post_process.compute_vaf( + traj, use_velocities=True, filter_atoms=(na, cl) + ) + with open(cor_path) as cor: + vaf = safe_load(cor) + vaf_na = np.array(vaf["vaf_Na"]["value"]) + vaf_cl = np.array(vaf["vaf_Cl"]["value"]) + assert vaf_na * 3 == approx(vaf_post[1][0], rel=1e-5) + assert vaf_cl * 3 == approx(vaf_post[1][1], rel=1e-5) + + +def test_md_correlations(tmp_path): + """Test correlations as part of MD cycle.""" + file_prefix = tmp_path / "Cl4Na4-nve-T300.0" + traj_path = tmp_path / "Cl4Na4-nve-T300.0-traj.extxyz" + cor_path = tmp_path / "Cl4Na4-nve-T300.0-cor.dat" + + single_point = SinglePoint( + struct_path=DATA_PATH / "NaCl.cif", + arch="mace", + calc_kwargs={"model": MODEL_PATH}, + ) + + nve = NVE( + struct=single_point.struct, + temp=300.0, + steps=10, + seed=1, + traj_every=1, + stats_every=1, + file_prefix=file_prefix, + correlation_kwargs=[ + { + "a": Stress(components=[("xy")]), + "b": Stress(components=[("xy")]), "name": "stress_xy_auto_cor", "blocks": 1, "points": 11, @@ -127,8 +177,7 @@ def user_observable_a(atoms: Atoms, kappa, **kwargs) -> float: assert cor_path.exists() with open(cor_path, encoding="utf8") as in_file: cor = load(in_file, Loader=Loader) - assert len(cor) == 2 - assert "user_correlation" in cor + assert len(cor) == 1 assert "stress_xy_auto_cor" in cor stress_cor = cor["stress_xy_auto_cor"] @@ -138,11 +187,3 @@ def user_observable_a(atoms: Atoms, kappa, **kwargs) -> float: direct = correlate(pxy, pxy, fft=False) # input data differs due to i/o, error is expected 1e-5 assert direct == approx(value, rel=1e-5) - - user_cor = cor["user_correlation"] - value, lags = user_cor["value"], stress_cor["lags"] - assert len(value) == len(lags) == 11 - - direct = correlate([v * 4.0 for v in pxy], pxy, fft=False) - # input data differs due to i/o, error is expected 1e-5 - assert direct == approx(value, rel=1e-5) diff --git a/tests/test_utils.py b/tests/test_utils.py index 32fb9ecc..ad6d7868 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -9,9 +9,15 @@ import pytest from janus_core.cli.utils import dict_paths_to_strs, dict_remove_hyphens +from janus_core.helpers.janus_types import SliceLike, StartStopStep from janus_core.helpers.mlip_calculators import choose_calculator from janus_core.helpers.struct_io import output_structs -from janus_core.helpers.utils import none_to_dict +from janus_core.helpers.utils import ( + none_to_dict, + selector_len, + slicelike_to_startstopstep, + validate_slicelike, +) DATA_PATH = Path(__file__).parent / "data/NaCl.cif" MODEL_PATH = Path(__file__).parent / "models/mace_mp_small.model" @@ -166,3 +172,81 @@ def test_none_to_dict(dicts_in): assert dicts[2] == dicts_in[2] assert dicts[3] == dicts_in[3] assert dicts[4] == {} + + +@pytest.mark.parametrize( + "slc, expected", + [ + ((1, 2, 3), (1, 2, 3)), + (1, (1, 2, 1)), + (range(1, 2, 3), (1, 2, 3)), + (slice(1, 2, 3), (1, 2, 3)), + (-1, (-1, None, 1)), + (range(10), (0, 10, 1)), + (slice(0, None, 1), (0, None, 1)), + ], +) +def test_slicelike_to_startstopstep(slc: SliceLike, expected: StartStopStep): + """Test converting SliceLike to StartStopStep.""" + assert slicelike_to_startstopstep(slc) == expected + + +@pytest.mark.parametrize( + "slc, len, expected", + [ + ((1, 2, 3), 3, 1), + (1, 1, 1), + (range(1, 2, 3), 3, 1), + (slice(1, 2, 3), 3, 1), + (-1, 5, 1), + (-3, 4, 1), + (range(10), 10, 10), + (slice(0, None, 2), 10, 5), + ([-2, -1, 0], 9, 3), + ([-1], 10, 1), + ([0, -1, 2, 9], 10, 4), + ], +) +def test_selector_len(slc: SliceLike | list[int], len: int, expected: int): + """Test converting SliceLike to StartStopStep.""" + assert selector_len(slc, len) == expected + + +@pytest.mark.parametrize( + "slc", + [ + slice(0, 1, 1), + slice(0, None, 1), + range(3), + range(0, 10, 1), + -1, + 0, + 1, + (0), + (0, 1, 1), + (0, None, 1), + ], +) +def test_valid_slicelikes(slc): + """Test validate_slicelike on valid SliceLikes.""" + validate_slicelike(slc) + + +@pytest.mark.parametrize( + "slc", + [ + 1.0, + "", + None, + [1], + (None, 0, None), + (0, 1, None), + (None, None, None), + (0, 1), + (0, 1, 2, 3), + ], +) +def test_invalid_slicelikes(slc): + """Test validate_slicelike on invalid SliceLikes.""" + with pytest.raises(ValueError): + validate_slicelike(slc)