Skip to content

Commit

Permalink
Saving timesteps parameter (#305)
Browse files Browse the repository at this point in the history
* checkpoint, now save compartment sizes at each date in COMPARTMENT_SAVE_DATES

* bugfix sim_day to _checkpoint_compartment_sizes()

* checkpoint trying to debug OOM issues when saving timesteps

* excluding all timesteps variable

* timesteps feature working

* bugfix boolean logic or -> and
  • Loading branch information
arik-shurygin authored Dec 17, 2024
1 parent c665be2 commit c758a9d
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 30 deletions.
10 changes: 10 additions & 0 deletions src/dynode/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,16 @@ class is accepted to modify/create the downstream parameters.
# "validate": do_nothing,
"type": lambda s: datetime.datetime.strptime(s, "%Y-%m-%d").date(),
},
{
# list[date] on which the user wishes to save the state of each
# compartment; the final timestep is always saved automatically
"name": "COMPARTMENT_SAVE_DATES",
# "validate": do_nothing,
# type list[date]
"type": lambda lst: [
datetime.datetime.strptime(s, "%Y-%m-%d").date() for s in lst
],
},
{
"name": "VACCINATION_SEASON_CHANGE",
# "validate": do_nothing,
Expand Down
50 changes: 36 additions & 14 deletions src/dynode/dynode_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ def save_inference_posteriors(
self,
inferer: MechanisticInferer,
save_filename="checkpoint.json",
exclude_prefixes=["final_timestep"],
exclude_prefixes=["timestep"],
save_chains_plot=True,
save_pairs_correlation_plot=True,
) -> None:
Expand All @@ -443,7 +443,8 @@ def save_inference_posteriors(
exclude_prefixes: list[str], optional
a list of strs that, if found in a sample name,
are excluded from the saved json. This is common for large logging
info that will bloat filesize like, by default ["final_timestep"]
info that will bloat filesize like, by default ["timestep"]
to exclude all timestep deterministic variables.
save_chains_plot: bool, optional
whether to save accompanying mcmc chains plot, by default True
save_pairs_correlation_plot: bool, optional
def save_inference_final_timesteps(
    self,
    inferer: "MechanisticInferer",
    save_filename="final_timesteps.json",
    final_timestep_identifier="final_timestep",
):
    """saves the `final_timestep` posteriors, if they are found in
    mcmc.get_samples(), otherwise raises a warning and saves nothing.

    Thin wrapper around `self.save_inference_timesteps()` that narrows the
    saved samples down to those identifying the final timestep.

    Parameters
    ----------
    inferer : MechanisticInferer
        inferer that was run with `inferer.infer()`
    save_filename : str, optional
        output filename, by default "final_timesteps.json"
    final_timestep_identifier : str, optional
        token identifying the final_timestep parameters within the sample
        names, by default "final_timestep"
    """
    # forward the caller's identifier rather than a hard-coded literal,
    # otherwise a non-default `final_timestep_identifier` is ignored
    self.save_inference_timesteps(
        inferer,
        save_filename,
        timestep_identifier=final_timestep_identifier,
    )

def save_inference_timesteps(
    self,
    inferer: MechanisticInferer,
    save_filename="timesteps.json",
    timestep_identifier="timestep",
):
    """saves every posterior sample whose name contains
    `timestep_identifier`, warning and saving nothing if inference has not
    completed or if no such sample exists.

    Parameters
    ----------
    inferer : MechanisticInferer
        inferer that was run with `inferer.infer()`
    save_filename : str, optional
        output filename, joined onto `self.azure_output_dir`,
        by default "timesteps.json"
    timestep_identifier : str, optional
        identifying token searched for within each sample name,
        by default "timestep"
    """
    # nothing can be saved until inference has produced samples
    if not inferer.infer_complete:
        warnings.warn(
            "attempting to call `save_inference_timesteps` before inference is complete. Something is likely wrong..."
        )
        return
    samples = inferer.inference_algo.get_samples(group_by_chain=True)
    # keep only the sites whose name carries the identifying token
    selected = {}
    for site_name, site_values in samples.items():
        if timestep_identifier in site_name:
            selected[site_name] = site_values
    # if nothing matched, warn the user, save nothing
    if not selected:
        warnings.warn(
            "attempting to call `save_inference_timesteps` but failed to find any timesteps with prefix %s"
            % timestep_identifier
        )
        return
    save_path = os.path.join(self.azure_output_dir, save_filename)
    self._save_samples(selected, save_path)

def save_inference_timelines(
Expand Down
50 changes: 38 additions & 12 deletions src/dynode/mechanistic_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
observed metrics.
"""

import datetime
import json
from typing import Union

Expand All @@ -21,6 +22,7 @@
from .abstract_parameters import AbstractParameters
from .config import Config
from .mechanistic_runner import MechanisticRunner
from .utils import date_to_sim_day


class MechanisticInferer(AbstractParameters):
Expand Down Expand Up @@ -182,18 +184,7 @@ def likelihood(
dct = self.run_simulation(tf)
solution = dct["solution"]
predicted_metrics = dct["hospitalizations"]
numpyro.deterministic(
"final_timestep_s", solution.ys[self.config.COMPARTMENT_IDX.S][-1]
)
numpyro.deterministic(
"final_timestep_e", solution.ys[self.config.COMPARTMENT_IDX.E][-1]
)
numpyro.deterministic(
"final_timestep_i", solution.ys[self.config.COMPARTMENT_IDX.I][-1]
)
numpyro.deterministic(
"final_timestep_c", solution.ys[self.config.COMPARTMENT_IDX.C][-1]
)
self._checkpoint_compartment_sizes(solution)
predicted_metrics = jnp.maximum(predicted_metrics, 1e-6)
numpyro.sample(
"incidence",
Expand Down Expand Up @@ -247,6 +238,41 @@ def _debug_likelihood(self, **kwargs) -> bx.Model:
)
return bx_model

def _checkpoint_compartment_sizes(self, solution: Solution):
    """registers compartment sizes as numpyro deterministic sites: always
    the final timestep, plus every date listed in
    self.config.COMPARTMENT_SAVE_DATES when that parameter exists.
    Requested dates that fall outside of `solution` are skipped.

    Nothing is written to disk here; the values are only recorded as
    deterministic sites so they appear in
    `self.inference_algo.get_samples()` and can be saved later by
    self.checkpoint() or by the user.

    Site names are `final_timestep_<compartment>` for the last timestep
    and `<YYYY_MM_DD>_timestep_<compartment>` for each requested date.

    Parameters
    ----------
    solution : diffrax.Solution
        a diffrax Solution object returned by solving ODEs, most often
        retrieved by `self.run_simulation()`
    """
    # the last timestep is always recorded, one site per compartment
    for final_compartment in self.config.COMPARTMENT_IDX:
        numpyro.deterministic(
            "final_timestep_%s" % final_compartment.name,
            solution.ys[final_compartment][-1],
        )
    # COMPARTMENT_SAVE_DATES is optional; absent means no extra dates
    requested_dates = getattr(self.config, "COMPARTMENT_SAVE_DATES", [])
    for save_date in requested_dates:
        # each entry is expected to be a datetime.date
        date_label = save_date.strftime("%Y_%m_%d")
        sim_day = date_to_sim_day(save_date, self.config.INIT_DATE)
        # skip any day that `solution` does not actually contain
        if not (
            0 <= sim_day < len(solution.ys[self.config.COMPARTMENT_IDX.S])
        ):
            continue
        for dated_compartment in self.config.COMPARTMENT_IDX:
            numpyro.deterministic(
                "%s_timestep_%s" % (date_label, dated_compartment.name),
                solution.ys[dated_compartment][sim_day],
            )

def checkpoint(
self, checkpoint_path: str, group_by_chain: bool = True
) -> None:
Expand Down
8 changes: 4 additions & 4 deletions src/dynode/vis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,8 @@ def plot_checkpoint_inference_correlation_pairs(
for key, val in posteriors.items()
}
posteriors: dict[str, np.ndarray] = flatten_list_parameters(posteriors)
# drop any final_timestep parameters in case they snuck in
posteriors = drop_keys_with_substring(posteriors, "final_timestep")
# drop any timestep parameters in case they snuck in
posteriors = drop_keys_with_substring(posteriors, "timestep")
number_of_samples = posteriors[list(posteriors.keys())[0]].shape[1]
# if we are dealing with many samples per chain,
# narrow down to max_samples_calculated samples per chain
Expand Down Expand Up @@ -435,8 +435,8 @@ def plot_mcmc_chains(
for key, val in samples.items()
}
samples: dict[str, np.ndarray] = flatten_list_parameters(samples)
# drop any final_timestep parameters in case they snuck in
samples = drop_keys_with_substring(samples, "final_timestep")
# drop any timestep parameters in case they snuck in
samples = drop_keys_with_substring(samples, "timestep")
param_names = list(samples.keys())
num_params = len(param_names)
num_chains = samples[param_names[0]].shape[0]
Expand Down

0 comments on commit c758a9d

Please sign in to comment.