Skip to content

Commit

Permalink
checkpoint a new pattern for modifying compartment values
Browse files Browse the repository at this point in the history
  • Loading branch information
arik-shurygin committed Feb 20, 2025
1 parent dee31e1 commit b2edb23
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/dynode/pydantic_config/config_definition.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Top level classes for DynODE configs."""

from datetime import date
from typing import Callable, List, Optional
from typing import Callable, List, Optional, Union

from jax import Array
from jax import numpy as jnp
Expand Down Expand Up @@ -50,6 +50,13 @@ def shape(self) -> tuple[int, ...]:
"""Get shape of the compartment."""
return tuple([len(d_i) for d_i in self.dimensions])

def __setitem__(
self, index: Union[int, slice, tuple], value: float
) -> None:
"""Experimental function that sets a value in the JAX array using functional update with slicing support."""
assert isinstance(self.values, Array), "values is not an array"
self.values = self.values.at[index].set(value)


class ParamStore(BaseModel):
"""Miscellaneous parameters of an ODE model."""
Expand Down
16 changes: 16 additions & 0 deletions src/dynode/pydantic_config/covid_seip_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,22 @@ def __init__(self):
ode_function=seip_ode,
)

@property
def s(self) -> Compartment:
return self.compartments[0]

@property
def e(self) -> Compartment:
return self.compartments[1]

@property
def i(self) -> Compartment:
return self.compartments[2]

@property
def c(self) -> Compartment:
return self.compartments[3]

def _get_param_store(self, strains: list[Strain]) -> ParamStore:
return ParamStore(
strains=strains,
Expand Down
2 changes: 2 additions & 0 deletions src/dynode/pydantic_config/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class FullStratifiedImmuneHistory(Dimension):

def __init__(self, strains: list[Strain]) -> None:
"""Create a fully stratified immune history dimension."""
# TODO add a no-infection bin
strain_names = [s.strain_name for s in strains]
num_strains = len(strain_names)
all_immune_histories = []
Expand All @@ -73,6 +74,7 @@ class LastStrainImmuneHistory(Dimension):

def __init__(self, strains: list[Strain]) -> None:
"""Create an immune history dimension that only tracks last infected strain."""
# TODO add a no-infection bin
strain_names = [s.strain_name for s in strains]
bins: list[Bin] = [
CategoricalBin(name=state) for state in strain_names
Expand Down

0 comments on commit b2edb23

Please sign in to comment.