Skip to content

Commit

Permalink
adding some comments
Browse files Browse the repository at this point in the history
  • Loading branch information
arik-shurygin committed Jan 18, 2024
1 parent 149c72d commit a8156fd
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions mechanistic_compartments.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def get_args(
# either using the default sample_dist_dict, or the one provided by the user
# transform these distributions into numpyro samples.
for key, item in sample_dist_dict.items():
# if user wants to sample model initial infections, do that outside of get_args()
if key == "INITIAL_INFECTIONS":
continue
# sometimes you may want to sample the elements of a list, like R0 for strains
Expand Down Expand Up @@ -486,6 +487,7 @@ def run(
t0 = 0.0
dt0 = 1.0
saveat = SaveAt(ts=jnp.linspace(t0, tf, int(tf) + 1))
# if the user wants to sample model initial infections, do it here
initial_state = (
self.load_initial_state(
numpyro.sample(
Expand Down Expand Up @@ -1006,11 +1008,12 @@ def load_initial_state(self, initial_infections: float):
INIT_EXPOSED_DIST: loaded in config or via load_init_infection_infected_and_exposed_dist_via_abm()
INIT_IMMUNE_HISTORY: loaded in config or via load_immune_history_via_abm().
Modifies
Returns
----------
INITIAL_STATE: tuple(jnp.ndarray)
a tuple of len 4 representing the S, E, I, and C compartment population counts after model initialization.
"""
# create population distribution using INIT_INFECTED_DIST, then sum them for later use
initial_infectious_count = initial_infections * self.INIT_INFECTED_DIST
initial_infectious_count_ages = jnp.sum(
initial_infectious_count,
Expand All @@ -1020,6 +1023,7 @@ def load_initial_state(self, initial_infections: float):
self.I_AXIS_IDX.strain,
),
)
# create population distribution using INIT_EXPOSED_DIST, then sum them for later use
initial_exposed_count = initial_infections * self.INIT_EXPOSED_DIST
initial_exposed_count_ages = jnp.sum(
initial_exposed_count,
Expand All @@ -1029,13 +1033,13 @@ def load_initial_state(self, initial_infections: float):
self.I_AXIS_IDX.strain,
),
)
# suseptible / partial susceptible = Total population - infected - exposed
# suseptible / partial susceptible = Total population - infected_count - exposed_count
initial_suseptible_count = (
self.POPULATION
- initial_infectious_count_ages
- initial_exposed_count_ages
)[:, np.newaxis, np.newaxis, np.newaxis] * self.INIT_IMMUNE_HISTORY
# self.INITIAL_STATE =
# cumulative count always starts at zero
return (
initial_suseptible_count, # s
initial_exposed_count, # e
Expand Down

0 comments on commit a8156fd

Please sign in to comment.