Skip to content

Commit

Permalink
removing unnecessary equality checks already handled by pydantic, reo…
Browse files Browse the repository at this point in the history
…rganizing file structure, fixing doc string and mypy errors
  • Loading branch information
arik-shurygin committed Feb 24, 2025
1 parent 0f54ddd commit fca607d
Show file tree
Hide file tree
Showing 8 changed files with 122 additions and 235 deletions.
12 changes: 12 additions & 0 deletions src/dynode/pydantic_config/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,13 @@
"""DynODE configuration module."""

from .config_definition import (
Compartment,
CompartmentalModel,
InferenceParams,
Initializer,
)

alias = Compartment
alias = CompartmentalModel
alias = InferenceParams
alias = Initializer
22 changes: 12 additions & 10 deletions src/dynode/pydantic_config/bins.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Bin types for ODE compartment models."""

from typing import Optional

from pydantic import (
BaseModel,
NonNegativeFloat,
Expand All @@ -16,27 +14,31 @@ class Bin(BaseModel):

name: str

def __eq__(self, value):
if type(self) is type(value):
return self.__dict__ == value.__dict__
else:
return False


class DiscretizedPositiveIntBin(Bin):
"""Bin with a distinct discretized positive int inclusive min/max."""

min_value: NonNegativeInt
max_value: NonNegativeInt
name: Optional[str] = None

def __init__(self, min_value, max_value, name=None):
"""Initialize a Discretized bin with inclusive min/max and sensible default name.
Parameters
----------
min_value : int
minimum value contained by the bin (inclusive)
max_value : int
maximum value contained by the bin (inclusive)
name : str, optional
name of the bin, by default f"{min_value}_{max_value}" if None
"""
if name is None:
name = f"{min_value}_{max_value}"
super().__init__(name=name, min_value=min_value, max_value=max_value)

@model_validator(mode="after")
def bin_valid_side(self) -> Self:
def _bin_valid_side(self) -> Self:
"""Assert that min_value <= max_value."""
assert self.min_value <= self.max_value
return self
Expand Down
70 changes: 58 additions & 12 deletions src/dynode/pydantic_config/config_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ class Compartment(BaseModel):

@field_validator("name", mode="before")
@classmethod
def make_attr_compliant(cls, value: str) -> str:
def _verify_names(cls, value: str) -> str:
"""Validate to ensure names are always lowercase and underscored."""
if value.replace("_", "").isalpha():
return value.lower()
else:
Expand All @@ -47,7 +48,7 @@ def make_attr_compliant(cls, value: str) -> str:
)

@model_validator(mode="after")
def shape_match(self) -> Self:
def _shape_match(self) -> Self:
"""Set default values if unspecified, asserts dimensions and values shape matches."""
target_values_shape: tuple[int, ...] = tuple(
[len(d_i) for d_i in self.dimensions]
Expand All @@ -64,7 +65,23 @@ def shape(self) -> tuple[int, ...]:
"""Get shape of the compartment."""
return tuple([len(d_i) for d_i in self.dimensions])

def __eq__(self, value):
def __eq__(self, value) -> bool:
"""Check for equality definitions between two Compartments.
Parameters
----------
value : Any
Other value to compare, usually another Compartment
Returns
-------
bool
whether or not the two compartments are equal in name and dimension structure.
Note
----
does not check the values of the compartments, only their dimensionality and definition.
"""
if isinstance(value, Compartment):
if self.name == value.name and len(self.dimensions) == len(
value.dimensions
Expand All @@ -76,14 +93,32 @@ def __eq__(self, value):
return True
return False

# def __setitem__(
# self, index: Union[int, slice, tuple], value: float
# ) -> None:
# """Experimental function that sets a value in the JAX array using functional update with slicing support."""
# self.values = self.values.at[index].set(value)
def __setitem__(self, index: int | slice | tuple, value: float) -> None:
"""Set Compartment value in a numpy-like way.
# def __getitem__(self, index: Union[int, slice, tuple]) -> Any:
# return self.values.at[index].get()
Parameters
----------
index : int | slice | tuple
index or slice or tuple to index the Compartment's values.
value : float
float to set values[index] to.
"""
self.values = self.values.at[index].set(value)

def __getitem__(self, index: int | slice | tuple) -> Array:
"""Get the Compartment's values at some index.
Parameters
----------
index : int | slice | tuple
index to look up.
Returns
-------
Any
value of the `self.values` tensor at that index.
"""
return self.values.at[index].get()


class Initializer(BaseModel):
Expand Down Expand Up @@ -136,7 +171,7 @@ class CompartmentalModel(BaseModel):
def validate_shared_compartment_dimensions(self) -> Self:
"""Validate that any dimensions with same name across compartments are equal."""
# quad-nested for loops are not ideal, but lists are very small so this should be fine
dimension_map = {}
dimension_map: dict[str, Dimension] = {}
for compartment in self.compartments:
for dimension in compartment.dimensions:
if dimension.name in dimension_map:
Expand All @@ -150,7 +185,16 @@ def validate_shared_compartment_dimensions(self) -> Self:
return self

@model_validator(mode="after")
def validate_immune_histories(self):
def _validate_immune_histories(self):
"""Validate that the immune history dimensions within each compartment are initialized from the same strain definitions.
Example
-------
If you have 2 strains, `x` and `y`,
- a `FullStratifiedImmuneHistory` should have 4 bins, `none`, `x`, `y`, `x_y`
- a `LastStrainImmuneHistory` should have 3 bins, `none`, `x`, `y`
- Neither class should bins with any other strain `z` or exclude one of the required bins.
"""
strains = self.parameters.transmission_params.strains
# gather all ImmuneHistory dimensions
for compartment in self.compartments:
Expand Down Expand Up @@ -193,6 +237,8 @@ def get_compartment(self, compartment_name: str) -> Compartment:


class InferenceProcess(BaseModel):
"""Inference process for fitting a CompartmentalModel to data."""

model_config = ConfigDict(arbitrary_types_allowed=True)
model: CompartmentalModel
# includes observation method, specified at runtime.
Expand Down
25 changes: 9 additions & 16 deletions src/dynode/pydantic_config/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,9 @@ def __len__(self):
"""Get len of a Dimension."""
return len(self.bins)

def __eq__(self, value):
if isinstance(value, Dimension):
if self.name == value.name and len(self.bins) == len(value.bins):
# check that all bins are in same order and have same names/values
for bin_l, bin_r in zip(self.bins, value.bins):
if bin_l != bin_r:
return False
return True
return False

@field_validator("bins", mode="after")
@classmethod
def check_bins_same_type(cls, bins) -> Self:
def _check_bins_same_type(cls, bins) -> Self:
"""Assert all bins are of same type and bins is not empty."""
assert len(bins) > 0, "can not have dimension with no bins"
bin_type = type(bins[0])
Expand All @@ -44,7 +34,7 @@ def check_bins_same_type(cls, bins) -> Self:

@field_validator("bins", mode="after")
@classmethod
def check_bin_names_unique(cls, bins: list[Bin]) -> list[Bin]:
def _check_bin_names_unique(cls, bins: list[Bin]) -> list[Bin]:
assert len(bins) > 0, "can not have dimension with no bins"
names = [b.name for b in bins]
assert len(set(names)) == len(
Expand All @@ -57,18 +47,19 @@ def check_bin_names_unique(cls, bins: list[Bin]) -> list[Bin]:
def sort_discretized_int_bins(cls, bins: list[Bin]) -> list[Bin]:
"""Assert that DiscretizedPositiveIntBin do not overlap and sorts them lowest to highest."""
assert len(bins) > 0, "can not have dimension with no bins"
if isinstance(bins[0], DiscretizedPositiveIntBin):
if all(isinstance(bin, DiscretizedPositiveIntBin) for bin in bins):
# sort age bins with lowest min_value first
bins: list[DiscretizedPositiveIntBin] = sorted(
bins_sorted = sorted(
bins, key=lambda b: b.min_value, reverse=False
)
# assert that bins dont overlap now they are sorted
assert all(
[
bins[i].max_value < bins[i + 1].min_value
bins_sorted[i].max_value < bins[i + 1].min_value
for i in range(len(bins) - 1)
]
), "DiscretizedPositiveIntBin within a dimension can not overlap."
return bins_sorted
return bins


Expand Down Expand Up @@ -98,7 +89,9 @@ def __init__(self, strains: list[Strain]) -> None:
all_immune_histories = [Bin(name="none")]
for strain in range(1, len(strain_names) + 1):
combs = combinations(strain_names, strain)
all_immune_histories.extend(["_".join(comb) for comb in combs])
all_immune_histories.extend(
[Bin(name="_".join(comb)) for comb in combs]
)

super().__init__(name="hist", bins=all_immune_histories)

Expand Down
14 changes: 13 additions & 1 deletion src/dynode/pydantic_config/params.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Module containing Parameter classes for storing DynODE parameters."""

from typing import List

from numpyro.distributions import Distribution
Expand All @@ -8,27 +10,37 @@


class SolverParams(BaseModel):
"""Parameters used by the ODE solver."""

model_config = ConfigDict(arbitrary_types_allowed=True)
ode_solver_rel_tolerance: PositiveFloat
ode_solver_abs_tolerance: PositiveFloat


class TransmissionParams(BaseModel):
"""Transmission Parameters for the respiratory model."""

model_config = ConfigDict(arbitrary_types_allowed=True)
strain_interactions: dict[str, dict[str, NonNegativeFloat | Distribution]]
strains: List[Strain]


class InferenceParams(BaseModel):
"""Parameters necessary for inference."""

model_config = ConfigDict(arbitrary_types_allowed=True)


class MCMCParams(InferenceParams):
"""Inference parameters specific to Markov Chain Monte Carlo (MCMC) fitting methods."""

model_config = ConfigDict(arbitrary_types_allowed=True)
inference_mcmc_steps: PositiveFloat


class SVIParams(InferenceParams):
"""Inference parameters specific to Stochastic Variational Inference (SVI) fitting methods."""

model_config = ConfigDict(arbitrary_types_allowed=True)


Expand All @@ -41,7 +53,7 @@ class Params(BaseModel):
transmission_params: TransmissionParams

def realize_distributions(self) -> Self:
"""Go through parameters and sample all distribution objects
"""Go through parameters and sample all distribution objects.
Returns
-------
Expand Down
5 changes: 5 additions & 0 deletions src/dynode/pydantic_config/pre_packaged/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""A module containing pre-packaged Configs for easy off the shelf use."""

from .covid_seip_config import SEIPCovidModel

alias = SEIPCovidModel
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,20 @@

from dynode.model_odes.seip_model import seip_ode

from .bins import AgeBin, Bin, WaneBin
from .config_definition import (
from ..bins import AgeBin, Bin, WaneBin
from ..config_definition import (
Compartment,
CompartmentalModel,
Initializer,
Params,
)
from .dimension import Dimension, LastStrainImmuneHistory, VaccinationDimension
from .params import SolverParams, TransmissionParams
from .strains import Strain
from ..dimension import (
Dimension,
LastStrainImmuneHistory,
VaccinationDimension,
)
from ..params import SolverParams, TransmissionParams
from ..strains import Strain


class SEIPCovidModel(CompartmentalModel):
Expand All @@ -43,18 +47,22 @@ def __init__(self):

@property
def s(self) -> Compartment:
"""The Susceptible compartment of the model."""
return self.compartments[0]

@property
def e(self) -> Compartment:
"""The Exposed compartment of the model."""
return self.compartments[1]

@property
def i(self) -> Compartment:
"""The Infectious compartment of the model."""
return self.compartments[2]

@property
def c(self) -> Compartment:
"""The Cumulative compartment of the model."""
return self.compartments[3]

def _get_param_store(self, strains: list[Strain]) -> Params:
Expand Down
Loading

0 comments on commit fca607d

Please sign in to comment.