Skip to content

Commit

Permalink
Adds Velocity, ShearStress, StressDiagonal, observables (stfc#285)
Browse files Browse the repository at this point in the history
* Adds component Mixin and Velocity observable

* Support SliceLike as atoms in Velocity

* Move Observable into observables.py

* Remove getter

* Create SliceLike utils

* Rename builtins, multi-line error msg

* remove unneeded property

* Update developer guide

* CorrelationKwargs import Observable directly

* Simplify Stress __call__

* Add slicelike validator

* Remove value_count, update dev guide

* Check atoms is Atoms, clearer exception

* Components as a property with setter

* Use Sphinx :inherited-members:

* Split get into get_lags/get_value

---------

Co-authored-by: Jacob Wilkins <46597752+oerc0122@users.noreply.github.com>
Co-authored-by: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com>
Co-authored-by: Alin Marin Elena <alin@elena.re>
  • Loading branch information
4 people authored Nov 25, 2024
1 parent b547fb9 commit b603535
Show file tree
Hide file tree
Showing 10 changed files with 618 additions and 172 deletions.
1 change: 1 addition & 0 deletions docs/source/apidoc/janus_core.rst
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ janus\_core.processing.observables module
:private-members:
:undoc-members:
:show-inheritance:
:inherited-members:

janus\_core.processing.post\_process module
-------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@
("py:class", "Architectures"),
("py:class", "Devices"),
("py:class", "MaybeSequence"),
("py:class", "SliceLike"),
("py:class", "PathLike"),
("py:class", "Atoms"),
("py:class", "Calculator"),
Expand Down
81 changes: 74 additions & 7 deletions docs/source/developer_guide/tutorial.rst
Original file line number Diff line number Diff line change
Expand Up @@ -187,20 +187,87 @@ Alternatively, using ``tox``::
Adding a new Observable
=======================

Additional built-in observable quantities may be added for use by the ``janus_core.processing.correlator.Correlation`` class. These should conform to the ``__call__`` signature of ``janus_core.helpers.janus_types.Observable``. For a user this can be accomplished by writing a function, or class also implementing a commensurate ``__call__``.
A :class:`janus_core.processing.observables.Observable` abstracts obtaining a quantity derived from ``Atoms``. They may be used as kernels for input into analysis such as a correlation.

Built-in observables are collected within the ``janus_core.processing.observables`` module. For example the ``janus_core.processing.observables.Stress`` observable allows a user to quickly setup a given correlation of stress tensor components (with and without the ideal gas contribution). An observable for the ``xy`` component is obtained without the ideal gas contribution as:
Additional built-in observable quantities may be added for use by the :class:`janus_core.processing.correlator.Correlation` class. These should extend :class:`janus_core.processing.observables.Observable` and are implemented within the :py:mod:`janus_core.processing.observables` module.

The abstract method ``__call__`` should be implemented to obtain the values of the observed quantity from an ``Atoms`` object. When used as part of a :class:`janus_core.processing.correlator.Correlation`, each value will be correlated and the results averaged.

As an example of building a new ``Observable`` consider the :class:`janus_core.processing.observables.Stress` built-in. The following steps may be taken:

1. Defining the observable.
---------------------------

The stress tensor may be computed on an atoms object using ``Atoms.get_stress``. A user may wish to obtain a particular component, or perhaps only compute the stress on some subset of ``Atoms``. For example during a :class:`janus_core.calculations.md.MolecularDynamics` run a user may wish to correlate only the off-diagonal components (shear stress), computed across all atoms.

2. Writing the ``__call__`` method.
-----------------------------------

In the call method we can use the base :class:`janus_core.processing.observables.Observable`'s optional atom selector ``atoms_slice`` to first define the subset of atoms to compute the stress for:

.. code-block:: python
Stress("xy", False)
def __call__(self, atoms: Atoms) -> list[float]:
sliced_atoms = atoms[self.atoms_slice]
# must be re-attached after slicing for get_stress
sliced_atoms.calc = atoms.calc
A new built-in observables can be implemented by a class with the method:
Next the stresses may be obtained from:

.. code-block:: python
def __call__(self, atoms: Atoms, *args, **kwargs) -> float
stresses = (
sliced_atoms.get_stress(
include_ideal_gas=self.include_ideal_gas, voigt=True
)
/ units.GPa
)
Finally, to facilitate handling components in a symbolic way, :class:`janus_core.processing.observables.ComponentMixin` exists to parse ``str`` symbolic components to ``int`` indices by defining a suitable mapping. For the stress tensor (and the format of ``Atoms.get_stress``) a suitable mapping is defined in :class:`janus_core.processing.observables.Stress`'s ``__init__`` method:

The ``__call__`` should contain all the logic for obtaining some ``float`` value from an ``Atoms`` object, alongside optional positional arguments and kwargs. The args and kwargs are set by a user when specifying correlations for a ``janus_core.calculations.md.MolecularDynamics`` run. See also ``janus_core.helpers.janus_types.CorrelationKwargs``. These are set at the instantiation of the ``janus_core.calculations.md.MolecularDynamics`` object and are not modified. These could be used e.g. to specify an observable calculated only from one atom's data.
.. code-block:: python
ComponentMixin.__init__(
self,
components={
"xx": 0,
"yy": 1,
"zz": 2,
"yz": 3,
"zy": 3,
"xz": 4,
"zx": 4,
"xy": 5,
"yx": 5,
},
)
This then concludes the ``__call__`` method for :class:`janus_core.processing.observables.Stress` by using :class:`janus_core.processing.observables.ComponentMixin`'s
pre-calculated indices:

.. code-block:: python
return stesses[self._indices]
The combination of the above means a user may obtain, say, the ``xy`` and ``zy`` stress tensor components over odd-indexed atoms by calling the following observable on an ``Atoms``:

.. code-block:: python
s = Stress(components=["xy", "zy"], atoms_slice=(0, None, 2))
Since usually total system stresses are required we can define two built-ins to handle the shear and hydrostatic stresses like so:

.. code-block:: python
StressHydrostatic = Stress(components=["xx", "yy", "zz"])
StressShear = Stress(components=["xy", "yz", "zx"])
Where by default :class:`janus_core.processing.observables.Observable`'s ``atoms_slice`` is ``slice(0, None, 1)``, which expands to all atoms in an ``Atoms``.

For comparison the :class:`janus_core.processing.observables.Velocity` built-in's ``__call__`` not only returns atom velocity for the requested components, but also returns them for every tracked atom i.e:

.. code-block:: python
``janus_core.processing.observables.Stress`` includes a constructor to take a symbolic component, e.g. ``"xx"`` or ``"yz"``, and determine the index required from ``ase.Atoms.get_stress`` on instantiation for ease of use.
def __call__(self, atoms: Atoms) -> list[float]:
return atoms.get_velocities()[self.atoms_slice, :][:, self._indices].flatten()
37 changes: 6 additions & 31 deletions janus_core/helpers/janus_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,16 @@
from enum import Enum
import logging
from pathlib import Path, PurePath
from typing import (
IO,
Literal,
Optional,
Protocol,
TypedDict,
TypeVar,
Union,
runtime_checkable,
)
from typing import IO, TYPE_CHECKING, Literal, Optional, TypedDict, TypeVar, Union

from ase import Atoms
from ase.eos import EquationOfState
import numpy as np
from numpy.typing import NDArray

if TYPE_CHECKING:
from janus_core.processing.observables import Observable

# General

T = TypeVar("T")
Expand Down Expand Up @@ -86,32 +80,13 @@ class PostProcessKwargs(TypedDict, total=False):
vaf_output_file: PathLike | None


@runtime_checkable
class Observable(Protocol):
"""Signature for correlation observable getter."""

def __call__(self, atoms: Atoms, *args, **kwargs) -> float:
"""
Call the getter.
Parameters
----------
atoms : Atoms
Atoms object to extract values from.
*args : tuple
Additional positional arguments passed to getter.
**kwargs : dict
Additional kwargs passed getter.
"""


class CorrelationKwargs(TypedDict, total=True):
"""Arguments for on-the-fly correlations <ab>."""

#: observable a in <ab>, with optional args and kwargs
a: Observable | tuple[Observable, tuple, dict]
a: Observable
#: observable b in <ab>, with optional args and kwargs
b: Observable | tuple[Observable, tuple, dict]
b: Observable
#: name used for correlation in output
name: str
#: blocks used in multi-tau algorithm
Expand Down
87 changes: 86 additions & 1 deletion janus_core/helpers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
)
from rich.style import Style

from janus_core.helpers.janus_types import MaybeSequence, PathLike
from janus_core.helpers.janus_types import (
MaybeSequence,
PathLike,
SliceLike,
StartStopStep,
)


class FileNameMixin(ABC): # noqa: B024 (abstract-base-class-without-abstract-method)
Expand Down Expand Up @@ -432,3 +437,83 @@ def check_files_exist(config: dict, req_file_keys: Sequence[PathLike]) -> None:
# Only check if file key is in the configuration file
if not Path(config[file_key]).exists():
raise FileNotFoundError(f"{config[file_key]} does not exist")


def validate_slicelike(maybe_slicelike: SliceLike) -> None:
"""
Raise an exception if slc is not a valid SliceLike.
Parameters
----------
maybe_slicelike : SliceLike
Candidate to test.
Raises
------
ValueError
If maybe_slicelike is not SliceLike.
"""
if isinstance(maybe_slicelike, (slice, range, int)):
return
if isinstance(maybe_slicelike, tuple) and len(maybe_slicelike) == 3:
start, stop, step = maybe_slicelike
if (
(start is None or isinstance(start, int))
and (stop is None or isinstance(stop, int))
and isinstance(step, int)
):
return

raise ValueError(f"{maybe_slicelike} is not a valid SliceLike")


def slicelike_to_startstopstep(index: SliceLike) -> StartStopStep:
"""
Standarize `SliceLike`s into tuple of `start`, `stop`, `step`.
Parameters
----------
index : SliceLike
`SliceLike` to standardize.
Returns
-------
StartStopStep
Standardized `SliceLike` as `start`, `stop`, `step` triplet.
"""
validate_slicelike(index)
if isinstance(index, int):
if index == -1:
return (index, None, 1)
return (index, index + 1, 1)

if isinstance(index, (slice, range)):
return (index.start, index.stop, index.step)

return index


def selector_len(slc: SliceLike | list, selectable_length: int) -> int:
"""
Calculate the length of a selector applied to an indexable of a given length.
Parameters
----------
slc : Union[SliceLike, list]
The applied SliceLike or list for selection.
selectable_length : int
The length of the selectable object.
Returns
-------
int
Length of the result of applying slc.
"""
if isinstance(slc, int):
return 1
if isinstance(slc, list):
return len(slc)
start, stop, step = slicelike_to_startstopstep(slc)
if stop is None:
stop = selectable_length
return len(range(start, stop, step))
Loading

0 comments on commit b603535

Please sign in to comment.