Skip to content

Commit

Permalink
can now sample INITIAL_INFECTIONS, spent some time debugging transfor…
Browse files Browse the repository at this point in the history
…med distributions
  • Loading branch information
arik-shurygin committed Jan 18, 2024
1 parent 909c0fe commit 149c72d
Showing 1 changed file with 73 additions and 36 deletions.
109 changes: 73 additions & 36 deletions mechanistic_compartments.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,41 +92,8 @@ def __init__(self, **kwargs):
self.load_external_i_distributions()
# loads params used in self.external_i()

initial_infectious_count = (
self.INITIAL_INFECTIONS * self.INIT_INFECTED_DIST
)
initial_infectious_count_ages = jnp.sum(
initial_infectious_count,
axis=(
self.I_AXIS_IDX.hist,
self.I_AXIS_IDX.vax,
self.I_AXIS_IDX.strain,
),
)
initial_exposed_count = (
self.INITIAL_INFECTIONS * self.INIT_EXPOSED_DIST
)
initial_exposed_count_ages = jnp.sum(
initial_exposed_count,
axis=(
self.I_AXIS_IDX.hist,
self.I_AXIS_IDX.vax,
self.I_AXIS_IDX.strain,
),
)
# suseptible / partial susceptible = Total population - infected - exposed
initial_suseptible_count = (
self.POPULATION
- initial_infectious_count_ages
- initial_exposed_count_ages
)[:, np.newaxis, np.newaxis, np.newaxis] * self.INIT_IMMUNE_HISTORY

self.INITIAL_STATE = (
initial_suseptible_count, # s
initial_exposed_count, # e
initial_infectious_count, # i
jnp.zeros(initial_exposed_count.shape), # c
)
# load initial state using INIT_IMMUNE_HISTORY, INIT_INFECTED_DIST, and INIT_EXPOSED_DIST
self.INITIAL_STATE = self.load_initial_state(self.INITIAL_INFECTIONS)

self.solution = None

Expand Down Expand Up @@ -188,6 +155,8 @@ def get_args(
# either using the default sample_dist_dict, or the one provided by the user
# transform these distributions into numpyro samples.
for key, item in sample_dist_dict.items():
if key == "INITIAL_INFECTIONS":
continue
# sometimes you may want to sample the elements of a list, like R0 for strains
# check for that here:
if isinstance(item, list):
Expand Down Expand Up @@ -230,6 +199,9 @@ def get_args(
"SIGMA", 1 / args["EXPOSED_TO_INFECTIOUS"]
)
)
if "INITIAL_INFECTIONS" in args.keys():
# modifies initial state inplace, as it is not an arg passed into the ode.
self.load_initial_state(args["INITIAL_INFECTIONS"])
# since our last waning time is zero to account for last compartment never waning
# we include an if else statement to catch a division by zero error here.
waning_rates = [
Expand Down Expand Up @@ -514,13 +486,24 @@ def run(
t0 = 0.0
dt0 = 1.0
saveat = SaveAt(ts=jnp.linspace(t0, tf, int(tf) + 1))
initial_state = (
self.load_initial_state(
numpyro.sample(
"INITIAL_INFECTIONS",
sample_dist_dict["INITIAL_INFECTIONS"],
)
)
if "INITIAL_INFECTIONS" in sample_dist_dict.keys()
else self.INITIAL_STATE
)

solution = diffeqsolve(
term,
solver,
t0,
tf,
dt0,
self.INITIAL_STATE,
initial_state,
args=self.get_args(
sample=sample, sample_dist_dict=sample_dist_dict
),
Expand Down Expand Up @@ -1006,6 +989,60 @@ def zero_function(_):
pdf, loc=introduced_time, scale=7
)

def load_initial_state(self, initial_infections: float):
"""
a function which takes a number of initial infections, disperses them across infectious and exposed compartments
according to the INIT_INFECTED_DIST and INIT_EXPOSED_DIST distributions, then subtracts both those populations from the total population and
places the remaining individuals in the susceptible compartment, distributed according to the INIT_IMMUNE_HISTORY distribution.
Parameters
----------
initial_infections: the number of infections to disperse between infectious and exposed compartments.
Requires
----------
the following variables be loaded into self:
INIT_INFECTED_DIST: loaded in config or via load_init_infection_infected_and_exposed_dist_via_abm()
INIT_EXPOSED_DIST: loaded in config or via load_init_infection_infected_and_exposed_dist_via_abm()
INIT_IMMUNE_HISTORY: loaded in config or via load_immune_history_via_abm().
Modifies
----------
INITIAL_STATE: tuple(jnp.ndarray)
a tuple of len 4 representing the S, E, I, and C compartment population counts after model initialization.
"""
initial_infectious_count = initial_infections * self.INIT_INFECTED_DIST
initial_infectious_count_ages = jnp.sum(
initial_infectious_count,
axis=(
self.I_AXIS_IDX.hist,
self.I_AXIS_IDX.vax,
self.I_AXIS_IDX.strain,
),
)
initial_exposed_count = initial_infections * self.INIT_EXPOSED_DIST
initial_exposed_count_ages = jnp.sum(
initial_exposed_count,
axis=(
self.I_AXIS_IDX.hist,
self.I_AXIS_IDX.vax,
self.I_AXIS_IDX.strain,
),
)
# suseptible / partial susceptible = Total population - infected - exposed
initial_suseptible_count = (
self.POPULATION
- initial_infectious_count_ages
- initial_exposed_count_ages
)[:, np.newaxis, np.newaxis, np.newaxis] * self.INIT_IMMUNE_HISTORY
# self.INITIAL_STATE =
return (
initial_suseptible_count, # s
initial_exposed_count, # e
initial_infectious_count, # i
jnp.zeros(initial_exposed_count.shape), # c
)

def to_json(self, file=None):
"""
a simple method which takes self.config_file and dumps it into `file`.
Expand Down

0 comments on commit 149c72d

Please sign in to comment.