Skip to content

Commit

Permalink
adding verbose flag to dynode runner save_inference_timelines (#300)
Browse files Browse the repository at this point in the history
  • Loading branch information
arik-shurygin authored Dec 9, 2024
1 parent f26261c commit 4edfe5e
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
4 changes: 4 additions & 0 deletions src/dynode/dynode_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,7 @@ def save_inference_timelines(
extra_timelines: pd.DataFrame = None,
tf: Union[int, None] = None,
external_particle: dict[str, Array] = {},
verbose: bool = False,
) -> str:
"""saves history of inferer sampled values for use by the azure visualizer.
saves CSV file to `self.azure_output_dir/timeline_filename`.
Expand Down Expand Up @@ -551,6 +552,8 @@ def save_inference_timelines(
For example, loading a checkpoint.json containing saved posteriors from an Azure Batch job.
expects keys that match those given to `numpyro.sample` often from
inference_algo.get_samples(group_by_chain=True).
verbose: bool, optional
whether or not to pring out the current chain_particle value being executed
Returns
-------
Expand Down Expand Up @@ -596,6 +599,7 @@ def save_inference_timelines(
chain_particle_pairs,
tf=tf,
external_particle=external_particle,
verbose=verbose,
)
for (chain, particle), sol_dct in posteriors.items():
# content of `sol_dct` depends on return value of inferer.likelihood func
Expand Down
24 changes: 18 additions & 6 deletions src/dynode/mechanistic_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ def load_posterior_particle(
particles: Union[tuple[int, int], list[tuple[int, int]]],
tf: Union[int, None] = None,
external_particle: dict[str, jax.Array] = {},
verbose: bool = False,
) -> dict[
tuple[int, int],
dict[str, Union[Solution, jax.Array, dict[str, jax.Array]]],
Expand All @@ -306,6 +307,7 @@ def load_posterior_particle(
if `external_posteriors` are specified, uses them instead of self.inference_algo.get_samples()
to load static particle values.
Parameters
------------
particles: Union[tuple[int, int], list[tuple[int, int]]]
Expand All @@ -319,6 +321,9 @@ def load_posterior_particle(
For example, loading a checkpoint.json containing saved posteriors from a different run.
expects keys that match those given to `numpyro.sample` often from
inference_algo.get_samples(group_by_chain=True).
verbose: bool, optional
whether or not to pring out the current chain_particle value being executed
Returns
---------------
Expand All @@ -329,18 +334,20 @@ def load_posterior_particle(
Example
--------------
<insert 2 chain inference above>
load_posterior_particle([(0, 100), [1, 120],...]) = {(0, 100): {solution: diffrax.Solution, "posteriors": {...}},
(1, 120): {solution: diffrax.Solution, "posteriors": {...}} ...}
`load_posterior_particle([(0, 100), [1, 120],...]) = {(0, 100): {solution: diffrax.Solution, "posteriors": {...}},
(1, 120): {solution: diffrax.Solution, "posteriors": {...}} ...}`
Note
------------
Very important note if you choose to use `external_posteriors`. In the scenario
this instance of `MechanisticInferer.likelihood` samples parameters not named in `external_posteriors`
they will be RESAMPLED AT RANDOM. This method will not error and will instead fill in those
missing samples according to the PRNGKey seeded with self.config.INFERENCE_PRNGKEY as well as
unique salting of each chain_particle combination.
they will be RESAMPLED according to the distribution passed in the config.
This method will also salt the RNG key used on the prior according to the
chain & particule numbers currently being run.
This may be useful to you if you wish to obtain confidence intervals by varying a particular value.
This may be useful to you if you wish to fit upon some data, then introduce
a new varying parameter over the posteriors (often during projection).
"""
# if its a single particle, convert to len(1) list for simplicity
if isinstance(particles, tuple):
Expand Down Expand Up @@ -374,6 +381,11 @@ def load_posterior_particle(
for particle in particles:
# get the particle chain and number
chain_num, particle_num = particle
if verbose:
print(
"Executing (chain, particle): (%s, %s)"
% (str(chain_num), str(particle_num))
)
single_particle_samples = {}
# go through each posterior and select that specific chain and particle
for param in posterior_samples.keys():
Expand Down

0 comments on commit 4edfe5e

Please sign in to comment.