Skip to content

Commit

Permalink
adding a ton of validators and equivilance checking of compartments/d…
Browse files Browse the repository at this point in the history
…imensions/bins
  • Loading branch information
arik-shurygin committed Feb 22, 2025
1 parent 2fddefa commit 0f54ddd
Show file tree
Hide file tree
Showing 6 changed files with 250 additions and 75 deletions.
24 changes: 16 additions & 8 deletions src/dynode/pydantic_config/bins.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Bin types for ODE compartment models."""

from typing import Optional

from pydantic import (
BaseModel,
NonNegativeFloat,
Expand All @@ -12,20 +14,26 @@
class Bin(BaseModel):
"""A catch-all bin class meant to represent an individual cell of an ODE compartment."""

pass


class CategoricalBin(Bin):
"""Bin with a distinct name."""

name: str

def __eq__(self, value):
if type(self) is type(value):
return self.__dict__ == value.__dict__
else:
return False


class DiscretizedPositiveIntBin(Bin):
"""Bin with a distinct discretized positive int inclusive min/max."""

min_value: NonNegativeInt
max_value: NonNegativeInt
name: Optional[str] = None

def __init__(self, min_value, max_value, name=None):
if name is None:
name = f"{min_value}_{max_value}"
super().__init__(name=name, min_value=min_value, max_value=max_value)

@model_validator(mode="after")
def bin_valid_side(self) -> Self:
Expand All @@ -35,12 +43,12 @@ def bin_valid_side(self) -> Self:


class AgeBin(DiscretizedPositiveIntBin):
"""Age bin with inclusive mix and max age values."""
"""Age bin with inclusive mix and max age values, fills in name of bin for you."""

pass


class WaneBin(CategoricalBin):
class WaneBin(Bin):
"""Waning bin with a protection value and waning time in days."""

waning_time: NonNegativeInt
Expand Down
127 changes: 100 additions & 27 deletions src/dynode/pydantic_config/config_definition.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,30 @@
"""Top level classes for DynODE configs."""

from datetime import date
from typing import Callable, List, Optional, Union
from typing import Callable, List, Optional

from jax import Array
from jax import numpy as jnp
from numpyro.distributions import Distribution
from numpyro.infer import MCMC, SVI
from pydantic import (
BaseModel,
ConfigDict,
NonNegativeFloat,
Field,
PositiveFloat,
PositiveInt,
field_validator,
model_validator,
)
from typing_extensions import Self

from dynode import CompartmentGradiants

from .dimension import Dimension
from .strains import Strain
from .dimension import (
Dimension,
FullStratifiedImmuneHistory,
LastStrainImmuneHistory,
)
from .params import InferenceParams, Params


class Compartment(BaseModel):
Expand All @@ -30,15 +34,25 @@ class Compartment(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
name: str
dimensions: List[Dimension]
values: Optional[Array] = None
values: Array = Field(default_factory=lambda: jnp.array([]))

@field_validator("name", mode="before")
@classmethod
def make_attr_compliant(cls, value: str) -> str:
if value.replace("_", "").isalpha():
return value.lower()
else:
raise ValueError(
"the name field must not contain non-alpha chars with the exception of underscores"
)

@model_validator(mode="after")
def shape_match(self) -> Self:
"""Set default values if unspecified, asserts dimensions and values shape matches."""
target_values_shape: tuple[int, ...] = tuple(
[len(d_i) for d_i in self.dimensions]
)
if self.values is not None:
if bool(self.values.any()):
assert target_values_shape == self.values.shape
else:
# fill with default for now, values filled in at runtime.
Expand All @@ -50,23 +64,26 @@ def shape(self) -> tuple[int, ...]:
"""Get shape of the compartment."""
return tuple([len(d_i) for d_i in self.dimensions])

def __setitem__(
self, index: Union[int, slice, tuple], value: float
) -> None:
"""Experimental function that sets a value in the JAX array using functional update with slicing support."""
assert isinstance(self.values, Array), "values is not an array"
self.values = self.values.at[index].set(value)

def __eq__(self, value):
if isinstance(value, Compartment):
if self.name == value.name and len(self.dimensions) == len(
value.dimensions
):
# check both compartments have same dimensions in same order
for dim_l, dim_r in zip(self.dimensions, value.dimensions):
if dim_l != dim_r:
return False
return True
return False

class ParamStore(BaseModel):
"""Miscellaneous parameters of an ODE model."""
# def __setitem__(
# self, index: Union[int, slice, tuple], value: float
# ) -> None:
# """Experimental function that sets a value in the JAX array using functional update with slicing support."""
# self.values = self.values.at[index].set(value)

# allow users to pass custom types to ParamStore
model_config = ConfigDict(arbitrary_types_allowed=True)
strains: List[Strain]
strain_interactions: dict[str, dict[str, NonNegativeFloat | Distribution]]
ode_solver_rel_tolerance: PositiveFloat
ode_solver_abs_tolerance: PositiveFloat
# def __getitem__(self, index: Union[int, slice, tuple]) -> Any:
# return self.values.at[index].get()


class Initializer(BaseModel):
Expand Down Expand Up @@ -109,19 +126,75 @@ class CompartmentalModel(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
initializer: Initializer
compartments: List[Compartment]
parameters: ParamStore
parameters: Params
# passed to diffrax.diffeqsolve
ode_function: Callable[
[List[Compartment], PositiveFloat, ParamStore], CompartmentGradiants
[List[Compartment], PositiveFloat, Params], CompartmentGradiants
]
# includes observation method, specified at runtime.
inference_method: Optional[MCMC | SVI] = None

def get_compartment(self, compartment_name: str):
@model_validator(mode="after")
def validate_shared_compartment_dimensions(self) -> Self:
"""Validate that any dimensions with same name across compartments are equal."""
# quad-nested for loops are not ideal, but lists are very small so this should be fine
dimension_map = {}
for compartment in self.compartments:
for dimension in compartment.dimensions:
if dimension.name in dimension_map:
assert (
dimension == dimension_map[dimension.name]
), f"""dimension {dimension.name} has different definitions
across different compartments, if this intended, make
the dimensions have different names"""
else: # first time encountering this dimension name
dimension_map[dimension.name] = dimension
return self

@model_validator(mode="after")
def validate_immune_histories(self):
strains = self.parameters.transmission_params.strains
# gather all ImmuneHistory dimensions
for compartment in self.compartments:
for dimension in compartment.dimensions:
dim_class = type(dimension)
if (
dim_class is FullStratifiedImmuneHistory
or dim_class is LastStrainImmuneHistory
):
assert (
dim_class(strains) == dimension
), "Found immune states that dont correlate with strains from transmission_params"
return self

def get_compartment(self, compartment_name: str) -> Compartment:
"""Search the CompartmentModel and return a specific Compartment if it exists.
Parameters
----------
compartment_name : str
name of the compartment to return
Returns
-------
Compartment
Compartment class with matching name.
Raises
------
AssertionError
raise if `compartment_name` not found within `self.compartments`
"""
for compartment in self.compartments:
if compartment_name == compartment.name:
return compartment
raise AssertionError(
"Compartment with name %s not found in model, found only these names: %s"
% (compartment_name, str([c.name for c in self.compartments]))
)


class InferenceProcess(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
model: CompartmentalModel
# includes observation method, specified at runtime.
inference_method: Optional[MCMC | SVI] = None
inference_parameters: InferenceParams
34 changes: 19 additions & 15 deletions src/dynode/pydantic_config/covid_seip_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,16 @@

from dynode.model_odes.seip_model import seip_ode

from .bins import AgeBin, CategoricalBin, WaneBin
from .bins import AgeBin, Bin, WaneBin
from .config_definition import (
Compartment,
CompartmentalModel,
Initializer,
ParamStore,
Strain,
Params,
)
from .dimension import Dimension, LastStrainImmuneHistory, VaccinationDimension
from .params import SolverParams, TransmissionParams
from .strains import Strain


class SEIPCovidModel(CompartmentalModel):
Expand Down Expand Up @@ -56,8 +57,8 @@ def i(self) -> Compartment:
def c(self) -> Compartment:
return self.compartments[3]

def _get_param_store(self, strains: list[Strain]) -> ParamStore:
return ParamStore(
def _get_param_store(self, strains: list[Strain]) -> Params:
transmission_params = TransmissionParams(
strains=strains,
strain_interactions={
"omicron": {
Expand Down Expand Up @@ -106,8 +107,13 @@ def _get_param_store(self, strains: list[Strain]) -> ParamStore:
"jn1": 1.0,
},
},
ode_solver_rel_tolerance=1e-5,
ode_solver_abs_tolerance=1e-6,
)
return Params(
transmission_params=transmission_params,
solver_params=SolverParams(
ode_solver_rel_tolerance=1e-5,
ode_solver_abs_tolerance=1e-6,
),
)

def _get_strains(self) -> list[Strain]:
Expand All @@ -124,7 +130,7 @@ def _get_strains(self) -> list[Strain]:
),
infectious_period=7.0,
exposed_to_infectious=3.6,
vaccine_efficacy=[0, 0.35, 0.70],
vaccine_efficacy={0: 0, 1: 0.35, 2: 0.70},
),
Strain(
strain_name="ba2ba5",
Expand All @@ -136,7 +142,7 @@ def _get_strains(self) -> list[Strain]:
),
infectious_period=7.0,
exposed_to_infectious=3.6,
vaccine_efficacy=[0, 0.30, 0.60],
vaccine_efficacy={0: 0, 1: 0.30, 2: 0.60},
is_introduced=True,
introduction_time=dist.TruncatedNormal(
loc=20, scale=5, low=10
Expand All @@ -155,7 +161,7 @@ def _get_strains(self) -> list[Strain]:
),
infectious_period=7.0,
exposed_to_infectious=3.6,
vaccine_efficacy=[0, 0.30, 0.60],
vaccine_efficacy={0: 0, 1: 0.30, 2: 0.60},
is_introduced=True,
introduction_time=dist.TruncatedNormal(
loc=230, scale=5, low=190
Expand All @@ -174,7 +180,7 @@ def _get_strains(self) -> list[Strain]:
),
infectious_period=7.0,
exposed_to_infectious=3.6,
vaccine_efficacy=[0, 0.095, 0.19],
vaccine_efficacy={0: 0, 1: 0.095, 2: 0.19},
is_introduced=True,
introduction_time=dist.TruncatedNormal(
loc=640, scale=5, low=600
Expand All @@ -198,7 +204,7 @@ def _get_compartments(self, strains: list[Strain]) -> list[Compartment]:
)
immune_history_dimension = LastStrainImmuneHistory(strains=strains)
vaccination_dimension = VaccinationDimension(
max_ordinal_vaccinations=2, seasonal_vaccination=True
max_ordinal_vaccinations=2, seasonal_vaccination=False
)
waning_dimension = Dimension(
name="wane",
Expand All @@ -211,9 +217,7 @@ def _get_compartments(self, strains: list[Strain]) -> list[Compartment]:
)
infecting_strain_dimension = Dimension(
name="strain",
bins=[
CategoricalBin(name=strain.strain_name) for strain in strains
],
bins=[Bin(name=strain.strain_name) for strain in strains],
)
s_compartment = Compartment(
name="s",
Expand Down
Loading

0 comments on commit 0f54ddd

Please sign in to comment.