diff --git a/mechanistic_model/mechanistic_inferer.py b/mechanistic_model/mechanistic_inferer.py index ec1ba895..a920092a 100644 --- a/mechanistic_model/mechanistic_inferer.py +++ b/mechanistic_model/mechanistic_inferer.py @@ -13,10 +13,13 @@ import numpy as np import numpyro # type: ignore from diffrax import Solution +from jax import random from jax.random import PRNGKey from numpyro import distributions as Dist from numpyro.diagnostics import summary # type: ignore +from numpyro.handlers import seed, trace # type: ignore from numpyro.infer import MCMC, NUTS # type: ignore +from numpyro.infer.util import potential_energy import mechanistic_model.utils as utils from config.config import Config @@ -408,8 +411,7 @@ def load_posterior_particle( "run self.infer() first to produce posterior particles or pass externally produced particles" ) if tf is None: - # run for same amount of timesteps as given in inference - # given to exists since self.infer_complete is True + # run for same amount of timesteps as given print(e) # given to exists since self.infer_complete is True if hasattr(self, "inference_timesteps"): tf = self.inference_timesteps # unless user is using external_posterior, we may have not inferred yet @@ -466,3 +468,58 @@ def _load_posterior_single_particle( sol_dct = substituted_model(tf=tf, infer_mode=False) sol_dct["posteriors"] = substituted_model.data return sol_dct + + def stresstest( + self, N: int, scale: float = 1, **kwargs + ) -> jax.typing.ArrayLike: + """ + Perform a stress test on the model by generating random parameter values and + checking if the model fails for each parameter set. Model calls use `numpyro.infer.util.potential_energy` + with random parameters in unconstrained domain. Any parameter set causing a sample fail, + or returning `NaN` or `Inf` potential are returned. + + Parameters + ------------ + N (int): + The number of random parameter sets to generate for stress testing. + scale (float, optional): + A scaling factor to apply to the random parameter values. Defaults to 1. + kwargs: + Key word arguments passed to `loglikelihood`. + + Returns + --------------- + List[Dict[str, Any]]: A list of failing parameter sets, where each parameter set is a dictionary + mapping parameter keys to their corresponding values. + """ + # Execute the model to collect parameter keys + exec_trace = trace( + seed( + self.likelihood, + jax.random.PRNGKey(self.config.INFERENCE_PRNGKEY), + ) + ).get_trace(kwargs) + # Generate random parameter values with cauchy distribution + rand_vars = [ + random.cauchy(rk, (len(exec_trace.keys()),)) + for rk in random.split( + jax.random.PRNGKey(self.config.INFERENCE_PRNGKEY), N + ) + ] + rand_params = [ + {key: scale * x[i] for i, key in enumerate(exec_trace.keys())} + for x in rand_vars + ] + failing_params = [] + for param in rand_params: + try: + # potential_energy should raise an exception if the model fails + # and ingests parameters on the unconstrained domain + pe = potential_energy(self.likelihood, {}, {}, param) + if bool(jnp.isnan(pe)): + failing_params.append(param) + if bool(jnp.isinf(pe)): + failing_params.append(param) + except Exception as _: + failing_params.append(param) + return failing_params diff --git a/tests/test_inferer.py b/tests/test_inferer.py index 47ac2874..abe60e64 100644 --- a/tests/test_inferer.py +++ b/tests/test_inferer.py @@ -80,6 +80,13 @@ def test_load_posterior_particle(): ), "load_posterior_particle produced different timeline shapes than what was fit on" +def test_stresstest_runs(): + failed_params = inferer.stresstest(1000, tf=10) + assert ( + len(failed_params) >= 0 + ), "Params causing failure not returning as a list" + + def test_external_posteriors(): load_across_chains = [ (chain, 0) for chain in range(inferer.config.INFERENCE_NUM_CHAINS)