removing cholesky triangle priors from previous epoch loading technique (#295)

* removing cholesky triangle priors from previous epoch loading technique

* removing doc comments referencing old method
arik-shurygin authored Nov 18, 2024
1 parent f3563db commit 08582bf
Showing 1 changed file with 12 additions and 122 deletions.
134 changes: 12 additions & 122 deletions src/dynode/mechanistic_inferer.py
```diff
@@ -5,7 +5,6 @@
 """

 import json
-import warnings
 from typing import Union

 import bayeux as bx
```
```diff
@@ -16,10 +15,9 @@
 from diffrax import Solution  # type: ignore
 from jax.random import PRNGKey
 from numpyro import distributions as Dist
-from numpyro.diagnostics import summary  # type: ignore
 from numpyro.infer import MCMC, NUTS  # type: ignore

-from . import SEIC_Compartments, utils
+from . import SEIC_Compartments
 from .abstract_parameters import AbstractParameters
 from .config import Config
 from .mechanistic_runner import MechanisticRunner
```
```diff
@@ -38,32 +36,31 @@ def __init__(
         distributions_path: str,
         runner: MechanisticRunner,
         initial_state: SEIC_Compartments,
-        prior_inferer: MCMC = None,
     ):
         distributions_json = open(distributions_path, "r").read()
         global_json = open(global_variables_path, "r").read()
         self.config = Config(global_json).add_file(distributions_json)
         self.runner = runner
         self.INITIAL_STATE = initial_state
         self.infer_complete = False
-        self.set_infer_algo(prior_inferer=prior_inferer)
+        self.set_infer_algo()
         self.retrieve_population_counts()
         self.load_vaccination_model()
         self.load_contact_matrix()

-    def set_infer_algo(
-        self, prior_inferer: MCMC = None, inferer_type: str = "mcmc"
-    ) -> None:
-        """
-        Sets the inferer's inference algorithm and sampler.
-        If passed a previous inferer of the same inferer_type, uses posteriors to aid in the definition of new priors.
-        This does require special configuration parameters to aid in transition between sequential inferers.
+    def set_infer_algo(self, inferer_type: str = "mcmc") -> None:
+        """Sets the inferer's inference algorithm and sampler.

         Parameters
         ----------
-        prior_inferer: None, numpyro.infer.MCMC
-            the inferer algorithm of the previous sequential call to inferer.infer
-            use posteriors in this previous call to help define the priors in the current call.
         inferer_type : str, optional
             infer algo you wish to use, by default "mcmc"

         Raises
         ------
         NotImplementedError
             if passed `inferer_type` that is not yet supported, raises
             NotImplementedError
         """
         supported_infer_algos = ["mcmc"]
         if inferer_type.lower().strip() not in supported_infer_algos:
```
```diff
@@ -87,13 +84,6 @@ def set_infer_algo(
             num_chains=self.config.INFERENCE_NUM_CHAINS,
             progress_bar=self.config.INFERENCE_PROGRESS_BAR,
         )
-        if prior_inferer is not None:
-            # may want to look into this here:
-            # https://num.pyro.ai/en/stable/mcmc.html#id7
-            assert isinstance(
-                prior_inferer, MCMC
-            ), "the previous inferer is not of the same type."
-            self.set_posteriors_if_exist(prior_inferer)

     def _get_predictions(
         self, parameters: dict, solution: Solution
```
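After this hunk, all that survives of `set_infer_algo` is a plain normalize-and-validate step. A minimal standalone sketch of that pattern, for readers skimming the diff (the function name `resolve_infer_algo` is illustrative, not part of dynode):

```python
# Standalone sketch of the validation pattern set_infer_algo keeps:
# normalize the requested algorithm name and reject unsupported ones.
supported_infer_algos = ["mcmc"]


def resolve_infer_algo(inferer_type: str = "mcmc") -> str:
    normalized = inferer_type.lower().strip()
    if normalized not in supported_infer_algos:
        raise NotImplementedError(
            "%s is not yet supported, supported algorithms: %s"
            % (inferer_type, supported_infer_algos)
        )
    return normalized


print(resolve_infer_algo(" MCMC "))  # -> "mcmc"
```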
```diff
@@ -211,106 +201,6 @@ def likelihood(
             obs=obs_metrics,
         )

-    def set_posteriors_if_exist(self, prior_inferer: MCMC) -> None:
-        """
-        Given a `prior_inferer` object, look at its samples, check that each
-        sampled parameter has converging chains, then calculate the mean of
-        each parameter's samples as well as the covariance between all of the
-        parameter posterior distributions.
-
-        To exclude certain chains from use in posteriors, use the `DROP_CHAINS`
-        config argument as a list of chain indexes.
-        To exclude certain parameters from use in posteriors because you are not
-        sampling them this epoch, use the `DROP_POSTERIOR_PARAMETERS` config argument
-        as a list of sample names as they appear in `prior_inferer.print_summary()`.
-
-        Parameters
-        ----------
-        prior_inferer : MCMC
-            the inferer algorithm used in the previous epoch, or None.
-
-        Updates
-        -------
-        self.prior_inferer_particle_means : np.ndarray
-            non-dropped parameter means across all non-dropped chains
-        self.prior_inferer_particle_cov : np.ndarray
-            non-dropped parameter covariance across all non-dropped chains
-        self.prior_inferer_param_names : list[str]
-            non-dropped parameter names
-        self.cholesky_triangle_matrix : jnp.ndarray
-            a cholesky lower-triangle matrix over the non-dropped parameters,
-            used in the cholesky decomposition of a multivariate normal distribution
-
-        Returns
-        -------
-        None
-        """
-        if prior_inferer is not None:
-            # get all the samples from each chain run in previous inference
-            samples = prior_inferer.get_samples(group_by_chain=True)
-            # if a user does not want to use posteriors for certain parameters
-            # they can drop them using the DROP_POSTERIOR_PARAMETERS keyword
-            for parameter in getattr(
-                self.config, "DROP_POSTERIOR_PARAMETERS", []
-            ):
-                samples.pop(parameter, None)
-            # flatten any parameters that are created via numpyro.plate;
-            # these parameters add a dimension to `samples` values and mess with things
-            samples = utils.flatten_list_parameters(samples)
-            dropped_chains = []
-            if hasattr(self.config, "DROP_CHAINS"):
-                dropped_chains = self.config.DROP_CHAINS
-            # if the user specified they want certain chain indexes dropped, do that
-            samples = utils.drop_sample_chains(samples, dropped_chains)
-            # create a summary of these chains to calculate divergence of chains etc
-            sample_summaries = summary(samples)
-            # do some testing to ensure the chains are properly converging,
-            # then flatten all the samples from all chains together
-            samples_array_flattened = np.array([])
-            for sample in samples.keys():
-                sample_summary = sample_summaries[sample]
-                divergent_chains = False
-                if sample_summary["r_hat"] > 1.05:
-                    warnings.warn(
-                        "WARNING: the inferer has detected divergent chains in the %s parameter "
-                        "being passed as input into this epoch. "
-                        "Diverging chains can cause summary posterior distributions to not "
-                        "accurately reflect the true posterior distribution. "
-                        "You may use the DROP_CHAINS configuration parameter to "
-                        "drop the offending chain." % str(sample),
-                        RuntimeWarning,
-                    )
-                    divergent_chains = True
-                # now we add the parameter in flattened form for later
-                if not len(samples_array_flattened):
-                    samples_array_flattened = np.array(
-                        [samples[sample].flatten()]
-                    )
-                else:
-                    samples_array_flattened = np.concatenate(
-                        (
-                            samples_array_flattened,
-                            np.array([samples[sample].flatten()]),
-                        ),
-                        axis=0,
-                    )
-                # if we have divergent chains, warn and show them to the user
-                if divergent_chains:
-                    utils.plot_sample_chains(samples)
-            # samples_array_flattened is now of shape (P, N*M)
-            # for P parameters, N samples per chain and M chains per parameter
-            self.prior_inferer_particle_means = np.mean(
-                samples_array_flattened, axis=1
-            )
-            self.prior_inferer_particle_cov = np.cov(samples_array_flattened)
-            self.prior_inferer_param_names = list(samples.keys())
-            self.cholesky_triangle_matrix = jnp.linalg.cholesky(
-                self.prior_inferer_particle_cov
-            )
-
-        return None

     def infer(self, obs_metrics: jax.Array) -> MCMC:
         """
         Infer parameters given priors inside of self.config,
```
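For context on what was dropped: the removed `set_posteriors_if_exist` flattened previous-epoch posterior samples into a `(P, N*M)` array, took their mean and covariance, and stored a cholesky lower triangle so a later epoch could build a correlated multivariate normal prior. Below is a self-contained reconstruction of that idea using synthetic samples and hypothetical parameter names; it is a sketch of the removed technique, not dynode's API:

```python
import jax
import jax.numpy as jnp
import numpy as np
from numpyro import distributions as Dist

# Synthetic stand-in for prior_inferer.get_samples(group_by_chain=True):
# two hypothetical parameters, 4 chains x 500 samples each.
rng = np.random.default_rng(0)
samples = {
    "r0": rng.normal(2.5, 0.1, size=(4, 500)),
    "infectious_period": rng.normal(7.0, 0.5, size=(4, 500)),
}

# Flatten to shape (P, N*M): P parameters, N samples per chain, M chains.
flat = np.vstack([chains.flatten() for chains in samples.values()])

particle_means = np.mean(flat, axis=1)  # shape (P,)
particle_cov = np.cov(flat)             # shape (P, P), rows are parameters
# Lower-triangular factor L with L @ L.T == particle_cov.
cholesky_triangle = jnp.linalg.cholesky(jnp.asarray(particle_cov))

# A correlated prior for the next epoch: numpyro's MultivariateNormal
# accepts the cholesky factor directly via scale_tril.
prior = Dist.MultivariateNormal(
    loc=jnp.asarray(particle_means), scale_tril=cholesky_triangle
)
print(prior.sample(jax.random.PRNGKey(0)))  # one draw over all P parameters
```

With this commit that path is gone entirely: the inferer no longer accepts a `prior_inferer`, and each epoch's priors are defined by the distributions config alone.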
