From ef160840b8901a5e54eb3fbee6f91b6aeaef4ace Mon Sep 17 00:00:00 2001 From: Ariel Shurygin Date: Thu, 26 Sep 2024 20:48:24 +0000 Subject: [PATCH 1/5] checkpoint, updating the inferer_projection.py and changing the foi_suscept function to do slightly different defintion --- .../inferer_projection.py | 912 ++++++++++++++++++ .../postprocess_states_0.py | 82 ++ exp/projections_ihr2_2404_2507/run_task.py | 142 +++ .../scenarios_setup.py | 313 ++++++ .../seip_model_flatten_immune_hist.py | 148 +++ src/resp_ode/utils.py | 16 +- 6 files changed, 1608 insertions(+), 5 deletions(-) create mode 100644 exp/projections_ihr2_2404_2507/inferer_projection.py create mode 100644 exp/projections_ihr2_2404_2507/postprocess_states_0.py create mode 100644 exp/projections_ihr2_2404_2507/run_task.py create mode 100644 exp/projections_ihr2_2404_2507/scenarios_setup.py create mode 100644 src/resp_ode/model_odes/seip_model_flatten_immune_hist.py diff --git a/exp/projections_ihr2_2404_2507/inferer_projection.py b/exp/projections_ihr2_2404_2507/inferer_projection.py new file mode 100644 index 00000000..360b7a17 --- /dev/null +++ b/exp/projections_ihr2_2404_2507/inferer_projection.py @@ -0,0 +1,912 @@ +import os +from functools import partial + +import jax +import jax.numpy as jnp +import numpy as np +import numpyro +import numpyro.distributions as Dist +from diffrax import Solution +from jax.random import PRNGKey +from jax.scipy.stats.norm import pdf +from numpyro.infer import MCMC + +from resp_ode import ( + MechanisticInferer, + MechanisticRunner, + SEIC_Compartments, + utils, +) +from resp_ode.config import Config + + +class ProjectionParameters(MechanisticInferer): + UPSTREAM_PARAMETERS = [ + "INIT_DATE", + "CONTACT_MATRIX", + "NUM_STRAINS", + "NUM_AGE_GROUPS", + "NUM_WANING_COMPARTMENTS", + "WANING_PROTECTIONS", + "MAX_VACCINATION_COUNT", + "STRAIN_INTERACTIONS", + "VACCINE_EFF_MATRIX", + "BETA_TIMES", + "STRAIN_R0s", + "INFECTIOUS_PERIOD" "EXPOSED_TO_INFECTIOUS", + "INTRODUCTION_TIMES", + "INTRODUCTION_SCALES", + "INTRODUCTION_PCTS" "INITIAL_INFECTIONS_SCALE", + "CONSTANT_STEP_SIZE", + "SEASONALITY_AMPLITUDE", + "SEASONALITY_SECOND_WAVE", + "SEASONALITY_SHIFT", + "MIN_HOMOLOGOUS_IMMUNITY", + "R0_MULTIPLIER", + "WANING_TIMES", + ] + + def __init__( + self, + global_variables_path: str, + distributions_path: str, + runner: MechanisticRunner, + prior_inferer: MCMC = None, + ): + """A specialized init method which does not take an initial state, this is because + posterior particles will contain the initial state used.""" + distributions_json = open(distributions_path, "r").read() + global_json = open(global_variables_path, "r").read() + self.config = Config(global_json).add_file(distributions_json) + self.runner = runner + self.infer_complete = False # flag once inference completes + self.set_infer_algo(prior_inferer=prior_inferer) + self.load_vaccination_model() + self.load_contact_matrix() + + def load_vaccination_model(self): + """ + an overridden version of the vaccine model so we can load + state-specific vaccination splines using the REGIONS parameter + """ + vax_spline_filename = "spline_fits_%s.csv" % ( + self.config.REGIONS[0].lower().replace(" ", "_") + ) + vax_spline_path = os.path.join( + self.config.VACCINATION_MODEL_DATA, vax_spline_filename + ) + self.config.VACCINATION_MODEL_DATA = vax_spline_path + super().load_vaccination_model() + + def infer( + self, + obs_hosps, + obs_hosps_days, + obs_sero_lmean, + obs_sero_lsd, + obs_sero_days, + obs_var_prop, + obs_var_days, + obs_var_sd, + ): + """ + OVERRIDEN TO ADD MORE DATA STREAMS TO COMPARE AGAINST + Infer parameters given priors inside of self.config, returns an inference_algo object with posterior distributions for each sampled parameter. + + + Parameters + ---------- + obs_hosps: jnp.ndarray: weekly hosp incidence values from NHSN + obs_hosps_days: list[int] the sim day on which each obs_hosps value is measured. + for example obs_hosps[0] = 0 = self.config.INIT_DATE + obs_sero_lmean: jnp.ndarray: observed seroprevalence in logit scale + obs_sero_lsd: jnp.ndarray: standard deviation of logit seroprevalence (use this to + control the magnitude of uncertainty / weightage of fitting) + obs_sero_days: list[int] the sim day on which each obs_sero value is measured. + e.g., [9, 23, ...] meaning that we have data on day 9, 23, ... + + Returns + ----------- + an inference object, often numpyro.infer.MCMC object used to infer parameters. + This can be used to print summaries, pass along covariance matrices, or query posterier distributions + """ + self.inference_algo.run( + rng_key=PRNGKey(self.config.INFERENCE_PRNGKEY), + obs_hosps=obs_hosps, + obs_hosps_days=obs_hosps_days, + obs_sero_lmean=obs_sero_lmean, + obs_sero_lsd=obs_sero_lsd, + obs_sero_days=obs_sero_days, + obs_var_prop=obs_var_prop, + obs_var_days=obs_var_days, + obs_var_sd=obs_var_sd, + ) + self.inference_algo.print_summary() + self.infer_complete = True + self.inference_timesteps = max(obs_hosps_days) + 1 + return self.inference_algo + + def _solve_runner( + self, parameters: dict, tf: int, runner: MechanisticRunner + ) -> Solution: + """an Overridden version of AbstractParameters._solve_runner() designed + specifically to rework the initial state to fit the new compartment + definitions before solving the transmission model. + + The fitting period uses 5 strains, while projection uses 7, so we must expand + those dimensions and reset the C compartment. + + Parameters + ---------- + parameters : dict + parameters object containing parameters required by the runner ODEs + tf : int + number of days to run the runner for + runner : MechanisticRunner + runner class designated with solving ODEs + + Returns + ------- + Solution + diffrax solution object returned from runner.run() + """ + initial_state = self.rework_initial_state(parameters["INITIAL_STATE"]) + + solution = runner.run( + initial_state, + args=parameters, + tf=tf, + ) + return solution + + def _get_predictions(self, parameters, solution): + """ + OVERRIDEN FUNCTION of MechanisticInferer._get_predictions() + + This overriden function will calculate hospitalizations differently, but also + will calculate variant proportions and serology of the predicted population. + + The rest of this comment block remains unchanged from the original function: + generates post-hoc predictions from solved timeseries in `Solution` and + parameters used to generate them within `parameters`. This will often be hospitalizations + but could be more than just that. + + Parameters + ---------- + parameters : dict + parameters object returned by `get_parameters()` possibly containing information about the + infection hospitalization ratio + solution : Solution + Solution object returned by `_solve_runner` or any call to `self.runner.run()` + containing compartment timeseries + + Returns + ------- + jax.Array or tuple[jax.Array] + one or more jax arrays representing the different post-hoc predictions generated from + `solution`. If fitting upon hospitalizations only, then a single jax.Array representing hospitalizations will be present. + + """ + # add 1 to idxs because we are stratified by time in the solution object + # sum down to just time x age bins + model_incidence = jnp.sum( + solution.ys[self.config.COMPARTMENT_IDX.C], + axis=( + self.config.I_AXIS_IDX.hist + 1, + self.config.I_AXIS_IDX.vax + 1, + self.config.I_AXIS_IDX.strain + 1, + ), + ) + # axis = 0 because we take diff across time + model_incidence = jnp.diff(model_incidence, axis=0) + # sample intrinsic infection hospitalization rate here + ihr_mult_prior_means = jnp.array([0.02, 0.05, 0.14]) + ihr_mult_prior_variances = ( + jnp.array( + [ + 9.36e-05, + 6.94e-05, + 0.00029, + ] + ) + / 4 + ) + + ihr_mult_prior_a = ( + ( + ihr_mult_prior_means + * (1 - ihr_mult_prior_means) + / ihr_mult_prior_variances + ) + - 1 + ) * ihr_mult_prior_means + ihr_mult_prior_b = ( + ( + ihr_mult_prior_means + * (1 - ihr_mult_prior_means) + / ihr_mult_prior_variances + ) + - 1 + ) * (1 - ihr_mult_prior_means) + + ihr_mult_0 = numpyro.sample( + "ihr_mult_0", Dist.Beta(ihr_mult_prior_a[0], ihr_mult_prior_b[0]) + ) + ihr_mult_1 = numpyro.sample( + "ihr_mult_1", Dist.Beta(ihr_mult_prior_a[1], ihr_mult_prior_b[1]) + ) + ihr_mult_2 = numpyro.sample( + "ihr_mult_2", Dist.Beta(ihr_mult_prior_a[2], ihr_mult_prior_b[2]) + ) + ihr_3 = numpyro.sample("ihr_3", Dist.Beta(40 * 10, 360 * 10)) + ihr = jnp.array([ihr_mult_0, ihr_mult_1, ihr_mult_2, 1]) * ihr_3 + + # sample ihr multiplier due to previous infection or vaccinations + ihr_immune_mult = numpyro.sample( + "ihr_immune_mult", Dist.Beta(100 * 6, 300 * 6) + ) + + # sample ihr multiplier due to JN1 (assuming JN1 has less severity) + # ihr_jn1_mult = numpyro.sample( + # "ihr_jn1_mult", Dist.Beta(400 * 4, 4 * 4) + # ) + ihr_jn1_mult = numpyro.deterministic("ihr_jn1_mult", 0.95) + + # calculate modelled hospitalizations based on the ihrs + # add 1 to wane because we have time dimension prepended + model_incidence = jnp.diff( + solution.ys[self.config.COMPARTMENT_IDX.C], + axis=0, + ) + + model_incidence_no_exposures_non_jn1 = jnp.sum( + model_incidence[:, :, 0, 0, :, :4], axis=(-1, -2) + ) + model_incidence_no_exposures_jn1 = jnp.sum( + model_incidence[:, :, 0, 0, :, 4:], axis=(-1, -2) + ) + model_incidence_wbooster_non_jn1 = jnp.sum( + model_incidence[:, :, :, 3, :, :4], axis=(-1, -2, -3) + ) + model_incidence_wbooster_jn1 = jnp.sum( + model_incidence[:, :, :, 3, :, 4:], axis=(-1, -2, -3) + ) + model_incidence_all_non_jn1 = jnp.sum( + model_incidence[:, :, :, :, :, :4], axis=(2, 3, 4, 5) + ) + model_incidence_all_jn1 = jnp.sum( + model_incidence[:, :, :, :, :, 4:], axis=(2, 3, 4, 5) + ) + model_incidence_wexp_no_booster_non_jn1 = ( + model_incidence_all_non_jn1 + - model_incidence_no_exposures_non_jn1 + - model_incidence_wbooster_non_jn1 + ) + model_incidence_wexp_no_booster_jn1 = ( + model_incidence_all_jn1 + - model_incidence_no_exposures_jn1 + - model_incidence_wbooster_jn1 + ) + + # calculate weekly model hospitalizations with the two IHRs we created + # TODO, should we average every 7 days or just pick every day from obs_metrics + booster_ihr_reduction = getattr( + self.config, "BOOSTER_IHR_REDUCTION", 0.0 + ) + model_hosps = ( + model_incidence_no_exposures_non_jn1 * ihr + + model_incidence_no_exposures_jn1 * ihr * ihr_jn1_mult + + model_incidence_wbooster_non_jn1 + * ihr + * ihr_immune_mult + * (1 - booster_ihr_reduction) + + model_incidence_wbooster_jn1 + * ihr + * ihr_immune_mult + * ihr_jn1_mult + * (1 - booster_ihr_reduction) + + model_incidence_wexp_no_booster_non_jn1 * ihr * ihr_immune_mult + + model_incidence_wexp_no_booster_jn1 + * ihr + * ihr_immune_mult + * ihr_jn1_mult + ) + ## Seroprevalence + never_infected = jnp.sum( + solution.ys[self.config.COMPARTMENT_IDX.S][:, :, 0, :, :], + axis=(2, 3), + ) + sim_seroprevalence = 1 - never_infected / parameters["POPULATION"] + strain_incidence = jnp.sum( + solution.ys[self.config.COMPARTMENT_IDX.C], + axis=( + self.config.C_AXIS_IDX.age + 1, + self.config.C_AXIS_IDX.hist + 1, + self.config.C_AXIS_IDX.vax + 1, + self.config.C_AXIS_IDX.wane + 1, + ), + ) + + return (model_hosps, sim_seroprevalence, strain_incidence) + + def run_simulation(self, tf): + """An override of the mechanistic_inferer.run_simulation() function in order to + save all needed timelines + + + Parameters + ---------- + tf : _type_ + _description_ + """ + parameters = self.get_parameters() + solution = self._solve_runner(parameters, tf, self.runner) + ( + hospitalizations, + sim_seroprevalence, + strain_incidence, + ) = self._get_predictions(parameters, solution) + return { + "solution": solution, + "hospitalizations": hospitalizations, + "sim_seroprevalence": sim_seroprevalence, + "strain_incidence": strain_incidence, + "parameters": parameters, + } + + def sample_strain_x_intro_time(self, offset): + """ + Samples a value of the strain X intro time based on the lags between introduction times of the fitted strains + """ + # use numpyro.sample to read in the posterior values of the intro times + # we only introduced strains BA2BA5 - KP, so go through those (index exclusive so ends at X) + past_introduction_times = jnp.array( + [ + numpyro.sample( + "INTRODUCTION_TIMES_%s" % idx, + numpyro.distributions.Normal(), + ) + for idx, _ in enumerate( + self.config.STRAIN_IDX._member_names_[ + self.config.STRAIN_IDX.BA2BA5 : self.config.STRAIN_IDX.KP + ] + ) + ] + ) + # get the day of year of intro for XBB1 and JN1 (fall variants) + init_yday = 42 # TODO: this is hardcoded + xbb1_yday = (init_yday + past_introduction_times[1]) % 365 + jn1_yday = (init_yday + past_introduction_times[3]) % 365 + mean_yday = (xbb1_yday + jn1_yday) / 2 + sd_yday = 14 # fix at 14 days sd + yday_dist = Dist.Normal(loc=mean_yday, scale=sd_yday) + + strain_x_intro_yday = numpyro.sample("INTRO_YDAY_X", yday_dist) + # ensure that intro_yday is non-negative + strain_x_intro_yday = jnp.max(jnp.array([0, strain_x_intro_yday])) + # if yday smaller than offset, add 365 to it + strain_x_intro_time_raw = strain_x_intro_yday - offset + strain_x_intro_time_raw = jnp.where( + strain_x_intro_time_raw < 0, + 365 + strain_x_intro_time_raw, + strain_x_intro_time_raw, + ) + + # max with zero to avoid negatives + strain_x_intro_time = numpyro.deterministic( + "INTRODUCTION_TIME_X", strain_x_intro_time_raw + ) + return strain_x_intro_time + + @partial(jax.jit, static_argnums=(0)) + def vaccination_rate(self, t): + vaccine_offset = getattr(self.config, "ZERO_VACCINE_DAY", 0.0) + vaccine_rate_mult = getattr( + self.config, "VACCINATION_RATE_MULTIPLIER", 1.0 + ) + vaccine_rate_mult = jnp.where( + t < vaccine_offset, 0.0, vaccine_rate_mult + ) + t_offset = jnp.where(t < vaccine_offset, 0.0, t - vaccine_offset) + return vaccine_rate_mult * super().vaccination_rate(t_offset) + + def generate_downstream_parameters(self, parameters: dict) -> dict: + """ + OVERRIDEN FUNCTION to add INITIAL_STATE into the parameters list + as we are loading posterior particles from a previous fit when projecting. + This function will only be called in the context of a loaded posterior particle + within `self.load_posterior_particle()` + + takes an existing parameters object and attempts to generate a number of + downstream dependent parameters, based on the values contained within `parameters`. + + Raises RuntimeError if a downstream parameter + does not find the necessary values it needs within `parameters` + + Example + --------- + if the parameter `Y = 1/X` then X must be defined within `parameters` and + we call `parameters["Y"] = 1 / parameters["X"]` + + Parameters + ---------- + parameters : dict + parameters dictionary generated by `self._get_upstream_parameters()` + containing static or sampled values on which downstream parameters may depend + + Returns + ------- + dict + an appended onto version of `parameters` with additional downstream parameters added. + """ + try: + # create parameters based on other possibly sampled parameters + beta = parameters["STRAIN_R0s"] / parameters["INFECTIOUS_PERIOD"] + gamma = 1 / parameters["INFECTIOUS_PERIOD"] + sigma = 1 / parameters["EXPOSED_TO_INFECTIOUS"] + # last waning time is zero since last compartment does not wane + # catch a division by zero error here. + waning_rates = np.array( + [ + 1 / waning_time if waning_time > 0 else 0 + for waning_time in parameters["WANING_TIMES"] + ] + ) + + # numpyro needs to believe it is sampling from something in order for the override to work + def fake_sampler(): + return numpyro.distributions.Normal() + + # fake_sampler will be overriden in runtime by the `final_timestep` values of the posteriors + # final_timestep refers to the final state of the system after fitting, aka day 0 of projection + # for more information on why this works look into MechanisticInferer.load_posterior_particle() + parameters["INITIAL_STATE"] = tuple( + [ + jnp.array( + numpyro.sample("final_timestep_s", fake_sampler()) + ), + jnp.array( + numpyro.sample("final_timestep_e", fake_sampler()) + ), + jnp.array( + numpyro.sample("final_timestep_i", fake_sampler()) + ), + jnp.array( + numpyro.sample("final_timestep_c", fake_sampler()) + ), + ] + ) + # inserting this line here, needs to be in freeze_params to avoid memory leaks + # of multiple chains modifying self.POPULATION + parameters["POPULATION"] = self.retrieve_population_counts( + parameters["INITIAL_STATE"] + ) + # if we are introducing a strain, the INTRODUCTION_TIMES array will be non-empty + # and if we specifically want to sample strain_x intro time, we set that flag to True + if ( + self.config.SAMPLE_STRAIN_X_INTRO_TIME + and parameters["INTRODUCTION_TIMES"][1] + ): + init_yday = parameters["INIT_DATE"].timetuple().tm_yday + strain_x_intro_time = self.sample_strain_x_intro_time( + offset=init_yday + ) + # reset our intro time to the sampled lag distribution intro time for strain X + parameters["INTRODUCTION_TIMES"][1] = strain_x_intro_time + # allows the ODEs to just pass time as a parameter, makes them look cleaner + external_i_function_prefilled = jax.tree_util.Partial( + self.external_i, + introduction_times=parameters["INTRODUCTION_TIMES"], + introduction_scales=parameters["INTRODUCTION_SCALES"], + introduction_pcts=parameters["INTRODUCTION_PCTS"], + population=parameters["POPULATION"], + ) + # # pre-calculate the minimum value of the seasonality curves + seasonality_function_prefilled = jax.tree_util.Partial( + self.seasonality, + seasonality_amplitude=parameters["SEASONALITY_AMPLITUDE"], + seasonality_second_wave=parameters["SEASONALITY_SECOND_WAVE"], + seasonality_shift=parameters["SEASONALITY_SHIFT"], + ) + # add final parameters, if your model expects added parameters, add them here + parameters = dict( + parameters, + **{ + "BETA": beta, + "SIGMA": sigma, + "GAMMA": gamma, + "WANING_RATES": waning_rates, + "EXTERNAL_I": external_i_function_prefilled, + "VACCINATION_RATES": self.vaccination_rate, + "BETA_COEF": self.beta_coef, + "SEASONAL_VACCINATION_RESET": self.seasonal_vaccination_reset, + "SEASONALITY": seasonality_function_prefilled, + "POPULATION": parameters["POPULATION"], + } + ) + # new code for projections in particular + + parameters["STRAIN_R0s"] = jnp.array( + [ + parameters["STRAIN_R0s"][0], + parameters["STRAIN_R0s"][1], + parameters["STRAIN_R0s"][2], + parameters["STRAIN_R0s"][3], + numpyro.deterministic( + "STRAIN_R0s_4", parameters["STRAIN_R0s"][2] + ), + parameters["R0_MULTIPLIER"] + * numpyro.deterministic( + "STRAIN_R0s_5", parameters["STRAIN_R0s"][3] + ), + numpyro.deterministic( + "STRAIN_R0s_6", parameters["STRAIN_R0s"][2] + ), + ] + ) + parameters["BETA"] = ( + parameters["STRAIN_R0s"] / parameters["INFECTIOUS_PERIOD"] + ) + avg_2strain_interaction = ( + parameters["STRAIN_INTERACTIONS"][5, 3] + + parameters["STRAIN_INTERACTIONS"][4, 2] + ) / 2 + avg_1strain_interaction = ( + parameters["STRAIN_INTERACTIONS"][5, 4] + + parameters["STRAIN_INTERACTIONS"][4, 3] + ) / 2 + immune_escape_64 = (1 - avg_2strain_interaction) * parameters[ + "STRAIN_INTERACTIONS" + ][6, 4] + immune_escape_65 = (1 - avg_1strain_interaction) * parameters[ + "STRAIN_INTERACTIONS" + ][6, 5] + parameters["STRAIN_INTERACTIONS"] = ( + parameters["STRAIN_INTERACTIONS"] + .at[6, 4] + .set(1 - immune_escape_64) + ) + parameters["STRAIN_INTERACTIONS"] = ( + parameters["STRAIN_INTERACTIONS"] + .at[6, 5] + .set(1 - immune_escape_65) + ) + # re-create the CROSSIMMUNITY_MATRIX since we modified the STRAIN_INTERACTIONS matrix + parameters[ + "CROSSIMMUNITY_MATRIX" + ] = utils.strain_interaction_to_cross_immunity2( + parameters["NUM_STRAINS"], + parameters["STRAIN_INTERACTIONS"], + ) + except KeyError as e: + err_txt = """Attempted to create a downstream parameter but was unable to find + the required upstream values within `parameters` this is likely because it was not included + within self.UPSTREAM_PARAMETERS and was therefore not collected + before generating the downstream params""" + raise RuntimeError(err_txt) from e + + return parameters + + def get_parameters(self): + """ + Overriding the get_parameters() method to work with an undefined initial state + because projections only define their initial state by sampling the `final_timestep` parameter + we need self.POPULATION to only be evaluated after self.INITIAL_STATE has been pulled from the posteriors + + While this appears to be the exact same as super().get_parameters() + the function it calls is overriden, and thus does very different things. + + it specifically difers because it does not use `self.POPULATION` since + it does not exist yet, it must be set only AFTER initial state exists. + """ + parameters = self._get_upstream_parameters() + parameters = self.generate_downstream_parameters(parameters) + return parameters + + def retrieve_population_counts(self, initial_state): + return np.sum( # sum together S+E+I compartments + np.array( + [ + np.sum( + compartment, + axis=( + self.config.S_AXIS_IDX.hist, + self.config.S_AXIS_IDX.vax, + self.config.S_AXIS_IDX.wane, + ), + ) # sum over all but age bin axis + for compartment in initial_state[ + : self.config.COMPARTMENT_IDX.C + ] # avoid summing the book-keeping C compartment + ] + ), + axis=(0), # sum across compartments, keep age bins + ) + + def rework_initial_state(self, initial_state): + """ + Take the original `initial_state` which is (4, 7, 3, 4) -> (4, 8, 4, 4) and add an + additional strain to the infection history and an additional vax tier and the infected by dimensions for E+I + """ + s_new = jnp.pad( + initial_state[0], [(0, 0), (0, 2), (0, 1), (0, 0)], mode="constant" + ) + e_new = jnp.pad( + initial_state[1], [(0, 0), (0, 2), (0, 1), (0, 2)], mode="constant" + ) + i_new = jnp.pad( + initial_state[2], [(0, 0), (0, 2), (0, 1), (0, 2)], mode="constant" + ) + c_new_shape = list(s_new.shape) + c_new_shape.append(i_new.shape[3]) + c_new = jnp.zeros(tuple(c_new_shape)) + initial_state = ( + s_new, + e_new, + i_new, + c_new, + ) + return initial_state + + def scale_initial_infections( + self, scale_factor, INITIAL_STATE + ) -> SEIC_Compartments: + """ + overriden version that does not use self.INITIAL_STATE + a function which modifies returns a modified version of + self.INITIAL_STATE scaling the number of initial infections by `scale_factor`. + + Preserves the ratio of the Exposed/Infectious compartment population sizes. + Does not modified self.INITIAL_STATE, returns a copy. + + Parameters + ---------- + scale_factor: float + a multiplier value >=0.0. + `scale_factor` < 1 reduces number of initial infections, + `scale_factor` == 1.0 leaves initial infections unchanged, + `scale_factor` > 1 increases number of initial infections. + + Returns + --------- + A copy of INITIAL_INFECTIONS with each compartment being scaled according to `scale_factor` + """ + pop_counts_by_compartment = jnp.array( + [ + jnp.sum(compartment) + for compartment in INITIAL_STATE[ + : self.config.COMPARTMENT_IDX.C + ] + ] + ) + initial_infections = ( + pop_counts_by_compartment[self.config.COMPARTMENT_IDX.E] + + pop_counts_by_compartment[self.config.COMPARTMENT_IDX.I] + ) + initial_susceptibles = pop_counts_by_compartment[ + self.config.COMPARTMENT_IDX.S + ] + # total_pop_size = initial_susceptibles + initial_infections + new_infections_size = scale_factor * initial_infections + # negative if scale_factor < 1.0 + gained_infections = new_infections_size - initial_infections + scale_factor_susceptible_compartment = 1 - ( + gained_infections / initial_susceptibles + ) + # multiplying E and I by the same scale_factor preserves their relative ratio + scale_factors = [ + scale_factor_susceptible_compartment, + scale_factor, + scale_factor, + 1.0, # for the C compartment, unchanged. + ] + # scale each compartment and return + initial_state = tuple( + [ + compartment * factor + for compartment, factor in zip(INITIAL_STATE, scale_factors) + ] + ) + return initial_state + + @partial(jax.jit, static_argnums=(0)) + def external_i( + self, + t, + introduction_times: jax.Array, + introduction_scales: jax.Array, + introduction_pcts: jax.Array, + population, + ) -> jax.Array: + """ + Given some time t, returns jnp.array of shape self.INITIAL_STATE[self.config.COMPARTMENT_IDX.I] representing external infected persons + interacting with the population. it does so by calling some function f_s(t) for each strain s. + + MUST BE CONTINUOUS AND DIFFERENTIABLE FOR ALL TIMES t. + + The stratafication of the external population is decided by the introduced strains, which are defined by + 3 parallel lists of the time they peak (`introduction_times`), + the number of external infected individuals introduced as a % of the tracked population (`introduction_pcts`) + and how quickly or slowly those individuals contact the tracked population (`introduction_scales`) + + Parameters + ---------- + `t`: float as Traced + current time in the model, due to the just-in-time nature of Jax this float value may be contained within a + traced array of shape () and size 1. Thus no explicit comparison should be done on "t". + + `introduction_times`: list[int] as Traced + a list representing the times at which external strains should be introduced, in days, after t=0 of the model + This list is ordered inversely to self.config.STRAIN_R0s. If 2 external strains are defined, the two + values in `introduction_times` will refer to the last 2 STRAIN_R0s, not the first two. + + `introduction_scales`: list[float] as Traced + a list representing the standard deviation of the curve that external strains are introduced with, in days + This list is ordered inversely to self.config.STRAIN_R0s. If 2 external strains are defined, the two + values in `introduction_times` will refer to the last 2 STRAIN_R0s, not the first two. + + `introduction_pcts`: list[float] as Traced + a list representing the proportion of each age bin in self.POPULATION[self.config.INTRODUCTION_AGE_MASK] + that will be exposed to the introduced strain over the entire course of the introduction. + This list is ordered inversely to self.config.STRAIN_R0s. If 2 external strains are defined, the two + values in `introduction_times` will refer to the last 2 STRAIN_R0s, not the first two. + + Returns + ----------- + external_i_compartment: jax.Array + jnp.array(shape=(self.INITIAL_STATE[self.config.COMPARTMENT_IDX.I].shape)) of external individuals to the system + interacting with susceptibles within the system, used to impact force of infection. + """ + + # define a function that returns 0 for non-introduced strains + def zero_function(_): + return 0 + + external_i_distributions = [ + zero_function for _ in range(self.config.NUM_STRAINS) + ] + introduction_percentage_by_strain = [0] * self.config.NUM_STRAINS + for introduced_strain_idx, ( + introduced_time, + introduction_scale, + introduction_perc, + ) in enumerate( + zip(introduction_times, introduction_scales, introduction_pcts) + ): + # earlier introduced strains earlier will be placed closer to historical strains (0 and 1) + dist_idx = ( + self.config.NUM_STRAINS + - self.config.NUM_INTRODUCED_STRAINS + + introduced_strain_idx + ) + # use a normal PDF with std dv + external_i_distributions[dist_idx] = partial( + pdf, loc=introduced_time, scale=introduction_scale + ) + introduction_percentage_by_strain[dist_idx] = introduction_perc + # with our external_i_distributions set up, now we can execute them on `t` + # set up our return value + external_i_compartment = jnp.zeros( + ( + self.config.NUM_AGE_GROUPS, + self.config.NUM_STRAINS + 1, + self.config.MAX_VACCINATION_COUNT + 1, + self.config.NUM_STRAINS, + ) + ) + introduction_age_mask = jnp.where( + jnp.array(self.config.INTRODUCTION_AGE_MASK), + 1, + 0, + ) + for strain in self.config.STRAIN_IDX: + external_i_distribution = external_i_distributions[strain] + introduction_perc = introduction_percentage_by_strain[strain] + external_i_compartment = external_i_compartment.at[ + introduction_age_mask, 0, 0, strain + ].set( + external_i_distribution(t) + * introduction_perc + * population[introduction_age_mask] + ) + return external_i_compartment + + def likelihood( + self, + obs_hosps, + obs_hosps_days, + obs_sero_lmean, + obs_sero_lsd, + obs_sero_days, + obs_var_prop, + obs_var_days, + obs_var_sd, + tf, + ): + """ + overridden likelihood that takes as input weekly hosp data starting from self.config.INIT_DATE + + Parameters + ---------- + obs_hosps: jnp.ndarray: weekly hosp incidence values from NHSN + obs_hosps_days: list[int] the sim day on which each obs_hosps value is measured. + for example obs_hosps[0] = 0 = self.config.INIT_DATE + obs_sero_lmean: jnp.ndarray: observed seroprevalence in logit scale + obs_sero_lsd: jnp.ndarray: standard deviation of logit seroprevalence (use this to + control the magnitude of uncertainty / weightage of fitting) + obs_sero_days: list[int] the sim day on which each obs_sero value is measured. + e.g., [9, 23, ...] meaning that we have data on day 9, 23, ... + """ + dct = self.run_simulation(tf) + # filtering predicted hospitalizations into observed days, then updating sampling algo + model_hosps = dct["hospitalizations"] + # obs_hosps_days = [6, 13, 20, ....] + # Incidence from day 0, 1, 2, ..., 6 goes to first bin, day 7 - 13 goes to second bin... + # break model_hosps into chunks of intervals and aggregate them + # first, find out which interval goes to which days + hosps_interval_ind = jnp.searchsorted( + jnp.array(obs_hosps_days), jnp.arange(max(obs_hosps_days) + 1) + ) + # for observed, multiply number by number of days within an interval + obs_hosps_interval = ( + obs_hosps + * jnp.bincount(hosps_interval_ind, length=len(obs_hosps_days))[ + :, None + ] + ) + # for simulated, aggregate by index + sim_hosps_interval = jnp.array( + [ + jnp.bincount(hosps_interval_ind, m, length=len(obs_hosps_days)) + for m in model_hosps.T + ] + ).T + mask_incidence = ~jnp.isnan(obs_hosps_interval) + with numpyro.handlers.mask(mask=mask_incidence): + numpyro.sample( + "incidence", + Dist.Poisson(sim_hosps_interval), + obs=obs_hosps_interval, + ) + + # filtering predicted sero prev into observed days, then updating sampling algo + sim_seroprevalence = dct["sim_seroprevalence"] + # filter to just observed days + sim_seroprevalence = sim_seroprevalence[obs_sero_days, ...] + sim_lseroprevalence = jnp.log( + sim_seroprevalence / (1 - sim_seroprevalence) + ) # logit seroprevalence + + mask_sero = ~jnp.isnan(obs_sero_lmean) + with numpyro.handlers.mask(mask=mask_sero): + numpyro.sample( + "lseroprevalence", + Dist.Normal(sim_lseroprevalence, obs_sero_lsd), + obs=obs_sero_lmean, + ) + + # filtering predicted var proportions into observed days, then updating sampling algo + + strain_incidence = dct["strain_incidence"] + strain_incidence = jnp.diff(strain_incidence, axis=0)[ + : (max(obs_var_days) + 1) + ] + var_interval_ind = jnp.searchsorted( + jnp.array(obs_var_days), jnp.arange(max(obs_var_days) + 1) + ) + strain_incidence_interval = jnp.array( + [ + jnp.bincount(var_interval_ind, m, length=len(obs_var_days)) + for m in strain_incidence.T + ] + ).T + sim_var_prop = jnp.array( + [incd / jnp.sum(incd) for incd in strain_incidence_interval] + ) + sim_var_sd = jnp.ones(sim_var_prop.shape) * obs_var_sd + + numpyro.sample( + "variant_proportion", + Dist.Normal(sim_var_prop, sim_var_sd), + obs=obs_var_prop, + ) diff --git a/exp/projections_ihr2_2404_2507/postprocess_states_0.py b/exp/projections_ihr2_2404_2507/postprocess_states_0.py new file mode 100644 index 00000000..36e8706c --- /dev/null +++ b/exp/projections_ihr2_2404_2507/postprocess_states_0.py @@ -0,0 +1,82 @@ +# %% +import argparse +import os + +import pandas as pd +from tqdm import tqdm + +OUTPUT_PATH = "/output/projections_ihr2_2404_2507/" +parser = argparse.ArgumentParser() + + +def save_collated_timeseries(output_path, job_id): + job_path = os.path.join(output_path, job_id) + states = [ + d + for d in os.listdir(job_path) + if os.path.isdir(os.path.join(job_path, d)) + ] + state_dfs = [] + for st in tqdm(states, desc="processing states "): + scenario_dfs = [] + state_path = os.path.join(job_path, st) + scens = [ + d + for d in os.listdir(state_path) + if os.path.isdir(os.path.join(state_path, d)) + ] + + for sc in scens: + csv_path = os.path.join( + state_path, sc, "azure_visualizer_timeline.csv" + ) + if os.path.exists( + os.path.join(state_path, sc, "azure_visualizer_timeline.csv") + ): + df = pd.read_csv( + csv_path, + usecols=[ + "chain_particle", + "date", + "pred_hosp_0_17", + "pred_hosp_18_49", + "pred_hosp_50_64", + "pred_hosp_65+", + "vaccination_0_17", + "vaccination_18_49", + "vaccination_50_64", + "vaccination_65+", + "JN1_strain_proportion", + "KP_strain_proportion", + "X_strain_proportion", + ], + ) + + df["state"] = st + df["scenario"] = sc + scenario_dfs.append(df) + + state_df = pd.concat(scenario_dfs) + state_df.to_csv( + os.path.join(state_path, "all_scens_projections.csv"), index=False + ) + state_dfs.append(state_df) + all_states_df = pd.concat(state_dfs) + all_states_df.to_csv( + os.path.join(job_path, "all_states_projections.csv"), index=False + ) + + +parser = argparse.ArgumentParser(description="Experiment Azure Launcher") +parser.add_argument( + "--job_id", + "-j", + type=str, + help="job ID of the azure job, must be unique", + required=True, +) + +if __name__ == "__main__": + args = parser.parse_args() + job_id: str = args.job_id + save_collated_timeseries(OUTPUT_PATH, job_id) diff --git a/exp/projections_ihr2_2404_2507/run_task.py b/exp/projections_ihr2_2404_2507/run_task.py new file mode 100644 index 00000000..9ba28df8 --- /dev/null +++ b/exp/projections_ihr2_2404_2507/run_task.py @@ -0,0 +1,142 @@ +# ruff: noqa: E402 +import argparse +import json +import os +import shutil +import sys + +import jax +import numpy as np + +# adding things to path since in a docker container pathing gets changed +sys.path.append("/app/") +sys.path.append("/input/exp/fifty_state_5strain_2202_2404/") +print(os.getcwd()) +from resp_ode import MechanisticRunner +from resp_ode.model_odes.seip_model_flatten_immune_hist import seip_ode +from src.mechanistic_azure.abstract_azure_runner import AbstractAzureRunner + +# sys.path.append(".") +# sys.path.append(os.getcwd()) +from .inferer_projection import ProjectionParameters + +jax.config.update("jax_enable_x64", True) + +# will be multiplied by number of chains to get total number of posteriors +NUM_SAMPLES_PER_STATE_PER_SCENARIO = 25 +HISTORICAL_FIT_PATH = ( + "/output/fifty_state_5strain_2202_2404/SMH_5strains_240807_v16" +) +EXP_ID = "projections_ihr2_2404_2507" + + +class ProjectionRunner(AbstractAzureRunner): + # __init__ already implemented by the abstract case + def __init__(self, azure_output_dir): + super().__init__(azure_output_dir) + + def process_state(self, state, jobid=None, local_run=False, scenario=None): + projection_period_num_days = 434 + posteriors_path = os.path.join( + HISTORICAL_FIT_PATH, + state, + ) + checkpoint_path = os.path.join(posteriors_path, "checkpoint.json") + assert os.path.exists(checkpoint_path), ( + "checkpoint does not exist for this state %s" % state + ) + posteriors = json.load(open(checkpoint_path, "r")) + # the final states of the fitting period are saved within posteriors + # step 1: define your paths, now in the input + state_config_path = os.path.join( + f"/input/exp/{EXP_ID}/{jobid}/states", + state, + ) + if local_run: + state_config_path = os.path.join( + f"/input/exp/{EXP_ID}/states", + state, + ) + assert os.path.exists(state_config_path), ( + "the state path %s does not exist" % state_config_path + ) + print("Running the following state: " + state + "\n") + # global_config include definitions such as age bin bounds and strain definitions + # Any value or data structure that needs context to be interpretted is here. + GLOBAL_CONFIG_PATH = os.path.join( + state_config_path, "config_global.json" + ) + # a config file that defines the scenario being run + INFERER_CONFIG_PATH = os.path.join( + state_config_path, "%s.json" % scenario + ) + + cg_path = os.path.join( + self.azure_output_dir, "config_global_used.json" + ) + ci_path = os.path.join( + self.azure_output_dir, "config_inferer_used.json" + ) + # if you are hitting this block, this means you are either running locally + # or you rerunning a jobid which is bad practice + if os.path.exists(cg_path): + print( + "You are overriding an existing job's outputs, " + "this is bad practice and can destroy reproducibility. Proceed with caution" + ) + os.remove(cg_path) + if os.path.exists(ci_path): + os.remove(ci_path) + shutil.copy(GLOBAL_CONFIG_PATH, cg_path) + shutil.copy(INFERER_CONFIG_PATH, ci_path) + + # sets up the initial conditions, initializer.get_initial_state() passed to runner + runner = MechanisticRunner(seip_ode) + inferer = ProjectionParameters( + GLOBAL_CONFIG_PATH, INFERER_CONFIG_PATH, runner + ) + # self.save_inference_posteriors(inferer) + np.random.seed(4326) + self.save_inference_timelines( + inferer, + particles_saved=NUM_SAMPLES_PER_STATE_PER_SCENARIO, + external_particle=posteriors, + tf=projection_period_num_days, + ) + + +parser = argparse.ArgumentParser() +parser.add_argument( + "-s", + "--state", + type=str, + help="directory for the state to run, resembles USPS code of the state", +) + +parser.add_argument( + "-j", "--jobid", type=str, help="job-id of the state being run on Azure" +) +parser.add_argument( + "-sc", "--scenario", type=str, help="scenario being run on Azure" +) +parser.add_argument( + "-l", "--local", action="store_true", help="scenario being run on Azure" +) + +if __name__ == "__main__": + args = parser.parse_args() + jobid: str = args.jobid + state: str = args.state + scenario: str = args.scenario + local: bool = args.local + # we are going to be rerouting stdout and stderror to files in our output blob + # stdout = sys.stdout + # stderror = sys.stderr + save_path = "/output/%s/%s/%s/%s/" % ( + EXP_ID, + jobid, + state, + scenario, + ) + runner = ProjectionRunner(save_path) + runner.process_state(state, jobid, local_run=local, scenario=scenario) diff --git a/exp/projections_ihr2_2404_2507/scenarios_setup.py b/exp/projections_ihr2_2404_2507/scenarios_setup.py new file mode 100644 index 00000000..488b024c --- /dev/null +++ b/exp/projections_ihr2_2404_2507/scenarios_setup.py @@ -0,0 +1,313 @@ +""" +A basic script that sets up experiments by taking a directory, creating a bunch of +state-specific folders within it and populating each folder with read-only copies +of the configuration files specified. This way each state can be run in parallel +where a single state only needs its config files and can store output within its +state-specific folder, to be collected later. +""" + +import argparse +import copy +import json +import os + +import numpy as np +import pandas as pd + +EXP_ID = "projections_ihr2_2404_2507" +EXP_FOLDER = f"exp/{EXP_ID}" +CONFIG_MOLDS = [ + f"exp/{EXP_ID}/template_configs/config_global.json", + f"exp/{EXP_ID}/template_configs/scenario_template.json", +] + +SCEN_CSV = f"exp/{EXP_ID}/scenarios.csv" + + +def create_state_subdirectories(dir, state_names): + """ + function to create an experiment directory `dir` and then create + subfolders for each Postal Abreviation in `state_names`. + Will not override if `dir` or `dir/state_names[i]` already exists + + Parameters + ------------ + `dir`: str + relative or absolute directory path of the experiment, + for which subdirectories per state will be created under it. + + `state_names`: list[str] + list of USPS postal codes per state involved in the experiment, will create subfolders of `dir` + with each code. + + Returns + ------------ + None + """ + # Create the main directory if it does not exist + if not os.path.exists(dir): + os.makedirs(dir) + + # Create subdirectories for each state inside the "states" folder + for state in state_names: + state_dir = os.path.join(dir, "states", state) + if not os.path.exists(state_dir): + os.makedirs(state_dir) + + +def create_multiple_scenarios_configs(state_config, state_abb, subdir_path): + """ + Create json for scenarios based on SCEN_CSV. This would include: + 1. intro_time, when does the new strain get introduced + 2. vaccination, 0 or 1, whether this is recommended/implemented + 3. immune_escape, scalar multiplier (in %) indicating if the immune escape of the + next major strain is the same (e.g. 1), higher (e.g., 1.2) or lower (e.g., 0.8) + 4. vaccine_efficacy, vaccine efficacy multiplier (in %), which is used to calculate VE of + seasonal booster via 1 - (1 - VE_2dose) * (1 - vaccine efficacy) + """ + print(state_abb) + # intro_time_lookup = { + # "aug": [32], + # "sep": [63], + # "oct": [93], + # "nov": [124], + # "dec": [154], + # "none": [], + # "sample": [1], + # } + df = pd.read_csv(SCEN_CSV) + df = df[["f" in x for x in df["id"]]] + dicts = [] + paths = [] + for index, row in df.iterrows(): + vs = row["vaccination"] + ir = row["ihr_reduction"] + ie = row["immune_escape"] + ve = row["vaccine_efficacy"] + kpit = row["kp_intro_time"] + mult = row["r0_percentage"] + + st_config = copy.deepcopy(state_config) + # vax0 no booster (except children), vax1 with booster across all age + if vs == 0: + st_config[ + "VACCINATION_MODEL_DATA" + ] = "/input/data/vaccination-data/2024_06_30_to_2025_06_28_vax0/" + else: + st_config[ + "VACCINATION_MODEL_DATA" + ] = "/input/data/vaccination-data/2024_06_30_to_2025_06_28_vax1/" + st_config["VACCINATION_RATE_MULTIPLIER"] = vs / 100 + + # intro time convert from month to day in model + st_config["INTRODUCTION_TIMES"][0] = kpit + st_config["INTRODUCTION_TIMES"][1] = 1 + st_config["SAMPLE_STRAIN_X_INTRO_TIME"] = "sample" + st_config["BOOSTER_IHR_REDUCTION"] = ir / 100 + # inject multiplier that get used in inferer_projection + st_config["STRAIN_INTERACTIONS"][6][4] = ie / 100 + st_config["STRAIN_INTERACTIONS"][6][5] = ie / 100 + # VE is calculated based on VE_2dose + vaccine_efficacy = np.array(st_config["VACCINE_EFF_MATRIX"]) + vaccine_efficacy[:, 3] = ve / 100 + # vaccine_efficacy[:, 3] = 1 - (1 - vaccine_efficacy[:, 2]) * ( + # 1 - ve / 100 + # ) + st_config["VACCINE_EFF_MATRIX"] = vaccine_efficacy.tolist() + st_config["R0_MULTIPLIER"] = mult / 100.0 + dicts.append(st_config) + new_config_file_path = os.path.join( + subdir_path, + f"vs{str(vs)}_ir{ir}_ie{str(ie)}_ve{str(ve)}_mult{str(mult)}_kpit{str(kpit)}.json", + ) + paths.append(new_config_file_path) + + return dicts, paths + + +def populate_config_files(dir, configs): + """ + scans an experiment directory `dir` opening each folder, and copying over read-only versions + of each json file in `configs`, modifying the "REGIONS" key to match the postal code. + Modifies the `POP_SIZE` variable to match the states population according to the census. + Modifies the `INITIAL_INFECTIONS` variable to equal the same % of the population as in the mold config. + eg: 2% of TOTAL_POP in mold config applied to each state's individual `POP_SIZE` + + will raise an error if a subdirectory of `dir` is not a postal code able to be looked up. + + Parameters + ------------ + `dir`: str + relative or absolute directory path of the experiment, + contains subdirectories created by `create_state_subdirectories` + + `configs`: list[str] + list of paths to each config mold, these config molds will be copied into each subdirectory as read-only + they will have their "REGIONS" key changed to resemble the state the subdirectory is modeling. + + Returns + ------------ + None + """ + dir = os.path.join(dir, "states") + for subdir in os.listdir(dir): + subdir_path = os.path.join(dir, subdir) + json_dicts = [] + output_paths = [] + if os.path.isdir(subdir_path): + state_name = code_to_state(subdir) + state_pop = code_to_pop(state_name) + + for config_file_path in configs: + # Read the original JSON file + with open(config_file_path) as f: + state_config = json.load(f) + + # Change the "REGION" key to state name + state_config["REGIONS"] = [state_name] + + if "POP_SIZE" in state_config.keys(): + # havent changed yet, so old value still in `state_config` + mold_pop_size = state_config["POP_SIZE"] + # match the same % of the population as in the mold config to new state POP_SIZE + if "INITIAL_INFECTIONS" in state_config.keys(): + mold_initial_inf = state_config["INITIAL_INFECTIONS"] + # state_pop * (% of infections in the mold config) + # round to 3 sig figs, convert to int + state_config["INITIAL_INFECTIONS"] = int( + float( + "%.3g" + % ( + state_pop + * (mold_initial_inf / mold_pop_size) + ) + ) + ) + # round pop sizes 3 sig figs then convert to int + state_config["POP_SIZE"] = int(float("%.3g" % state_pop)) + + if "scenario" in config_file_path: + dicts, paths = create_multiple_scenarios_configs( + state_config, subdir, subdir_path + ) + json_dicts.extend(dicts) + output_paths.extend(paths) + else: + json_dicts.append(state_config) + new_config_file_path = os.path.join( + subdir_path, os.path.basename(config_file_path) + ) + output_paths.append(new_config_file_path) + + # Create a new read-only copy of the JSON file with modified data + for d, p in zip(json_dicts, output_paths): + # if the config file already exists, we remove and override it. + if os.path.exists(p): + # change back from readonly so it can be deleted, otherwise get PermissionError + os.chmod(p, 0o777) + os.remove(p) + + with open(p, "w") as f: + json.dump(d, f, indent=4) + + # Set the new file permissions to read-only + os.chmod(p, 0o444) + + +def code_to_state(code): + """ + basic function to read in an postal code and return associated state name + + Parameters + ---------- + code: str + usps code the state + + Returns + ---------- + str/KeyError: state name, or KeyError if code does not point to a state or isnt an str + """ + state_info = state_names[state_names["stusps"] == code] + if len(state_info) == 1: + return state_info["stname"].iloc[0] + else: + raise KeyError("Unknown code %s" % code) + + +def code_to_pop(state_name): + """ + basic function to read in an postal code and return associated state name + + Parameters + ---------- + state_name: str + state name + + Returns + ---------- + str/KeyError: state population, or KeyError if invalid state name + """ + state_pop = pops[pops["STNAME"] == state_name] + if len(state_pop) == 1: + return state_pop["POPULATION"].iloc[0] + else: + raise KeyError("Unknown fips %s" % state_name) + + +def get_all_codes(): + return list(state_names["stusps"]) + + +# script takes arguments to specify the experiment being created. +parser = argparse.ArgumentParser() + +# list of fips codes +parser.add_argument( + "-s", + "--states", + type=str, + required=True, + nargs="+", + help="space separated list of str representing USPS postal code of each state", +) +# the molds of configs to bring into each state sub-dir +parser.add_argument( + "-m", + "--config_molds", + type=str, + required=False, + nargs="+", + default=CONFIG_MOLDS, + help="space separated paths to the config molds, defaults to some in /config", +) + +if __name__ == "__main__": + state_names = pd.read_csv("data/fips_to_name.csv") + pops = pd.read_csv("data/demographic-data/CenPop2020_Mean_ST.csv") + # adding a USA row with the sum of all state pops + usa_pop_row = pd.DataFrame( + [ + [ + "US", + "United States", + sum(pops["POPULATION"]), + None, + None, + ] + ], + columns=pops.columns, + ) + pops = pd.concat([pops, usa_pop_row], ignore_index=True) + args = parser.parse_args() + states = args.states + if "all" in states: + states = get_all_codes() + states.remove("US") + states.remove("DC") + + config_molds = args.config_molds + create_state_subdirectories(EXP_FOLDER, states) + populate_config_files(EXP_FOLDER, config_molds) + print( + "Created and populated state level directories with read-only copies of the config files" + ) diff --git a/src/resp_ode/model_odes/seip_model_flatten_immune_hist.py b/src/resp_ode/model_odes/seip_model_flatten_immune_hist.py new file mode 100644 index 00000000..51cb4845 --- /dev/null +++ b/src/resp_ode/model_odes/seip_model_flatten_immune_hist.py @@ -0,0 +1,148 @@ +import jax +import jax.numpy as jnp +from jaxtyping import ArrayLike, PyTree + +from resp_ode.utils import Parameters, get_foi_suscept + + +def seip_ode(state: PyTree, t: ArrayLike, parameters: dict): + s, e, i, c = state + if any([not isinstance(compartment, jax.Array) for compartment in state]): + raise TypeError( + "Please pass jax.numpy.array instead of np.array to ODEs" + ) + # spoof the dict into a class so we can use `p.` notation instead of dicts + p = Parameters(parameters) + ds, de, di, dc = ( + jnp.zeros(s.shape), + jnp.zeros(e.shape), + jnp.zeros(i.shape), + jnp.zeros(c.shape), + ) + beta_coef = p.BETA_COEF(t) + seasonality_coef = p.SEASONALITY(t) + # CALCULATING SUCCESSFULL INFECTIONS OF (partially) SUSCEPTIBLE INDIVIDUALS + # including externally infected individuals to introduce new strains + force_of_infection = ( + ( + p.BETA + * beta_coef + * seasonality_coef + * jnp.einsum( + "ab,bijk->ak", + p.CONTACT_MATRIX, + i + p.EXTERNAL_I(t), + ) + ).transpose() + / p.POPULATION + ).transpose() # (NUM_AGE_GROUPS, strain) + + foi_suscept = jnp.array(get_foi_suscept(p, force_of_infection)) + # we are vmaping this for loop. We select the force of infection + # for each strain, and calculated the number of susceptibles it exposes + # we sum over wane bin since `e` has no waning bin. + # OLD FOR LOOP FOR INTERPRETABILITY + # for strain in range(p.NUM_STRAINS): + # exposed_s = s * foi_suscept[strain] + # de = de.at[:, :, :, strain].add( + # jnp.sum(exposed_s, axis=-1) + # ) + # ds = jnp.add(ds, -exposed_s) + exposed_s = jnp.moveaxis( + jax.vmap( + lambda s, foi_suscept: s * foi_suscept, + in_axes=(None, 0), + )(s, foi_suscept), + 0, + -1, + ) # returns shape (s.shape..., p.NUM_STRAINS) + # s has waning as last dimension, e has infected strain as last dim + # the last two dimensions of `exposed_s` are `wane` and `strain` + # so lets sum over them to get the expected shape for each + de = de + jnp.sum(exposed_s, axis=-2) # remove wane so matches e.shape + ds = ds - jnp.sum(exposed_s, axis=-1) # remove strain so matches s.shape + dc = exposed_s # at this point we only have infections in de, so we add to cumulative + # e and i shape remain same, just multiplying by a constant. + de_to_i = p.SIGMA * e # exposure -> infectious + di_to_w0 = p.GAMMA * i # infectious -> new_immune_state + di = jnp.add(de_to_i, -di_to_w0) + de = jnp.add(de, -de_to_i) + + # go through all combinations of immune history and exposing strain + # calculate new immune history after recovery, place them there. + # THIS CODE REPLACES THE FOLLOWING FOR LOOP + # for strain, immune_state in product( + # range(p.NUM_STRAINS), range(2**p.NUM_STRAINS) + # ): + # new_state = new_immune_state(immune_state, strain, p.NUM_STRAINS) + # # recovered i->w0 transfer from `immune_state` -> `new_state` due to recovery from `strain` + # ds = ds.at[:, new_state, :, 0].add( + # di_to_w0[:, immune_state, :, strain] + # ) + for strain in range(p.NUM_STRAINS): + ds = ds.at[:, strain + 1, :, 0].add( + jnp.sum(di_to_w0[:, :, :, strain], axis=1) + ) + + # lets measure our waned + vax rates + # last w group doesn't wane but WANING_RATES enforces a 0 at the end + waning_array = jnp.zeros(s.shape).at[:, :, :].add(p.WANING_RATES) + s_waned = waning_array * s + ds = ds.at[:, :, :, 1:].add(s_waned[:, :, :, :-1]) + ds = ds.at[:, :, :, :-1].add(-s_waned[:, :, :, :-1]) + + # slice across age, strain, and wane. vaccination updates the vax column and also moves all to w0. + # ex: diagonal movement from 1 shot in 4th waning compartment to 2 shots 0 waning compartment s[:, 0, 1, 3] -> s[:, 0, 2, 0] + # input vaccination rate is per entire population, need to update to per compartments first + vax_rates = p.VACCINATION_RATES(t) + vax_totals = vax_rates * p.POPULATION[:, None] + vax_status_counts = jnp.sum( + s, axis=(1, 3) + ) # Sum over immune hist and waning to get count per age and vax status + updated_vax_rates = vax_totals / (vax_status_counts + 1e-10) + updated_vax_rates = jnp.minimum(updated_vax_rates, 0.95) + # updated_vax_rates = jnp.where( + # updated_vax_rates > 1.0, + # jnp.ones(updated_vax_rates.shape), + # updated_vax_rates, + # ) # prevent moving more people out than the compartments have + + # Assuming that people who received 2 or more doses wouldn't get additional booster too soon + # i.e., when they were still within the first waning compartment + vax_counts = s * updated_vax_rates[:, jnp.newaxis, :, jnp.newaxis] + vax_counts = vax_counts.at[:, :, p.MAX_VACCINATION_COUNT, 0].set(0) + vax_gained = jnp.sum(vax_counts, axis=(-1)) + ds = ds.at[:, :, p.MAX_VACCINATION_COUNT, 0].add( + vax_gained[:, :, p.MAX_VACCINATION_COUNT] + ) + ds = ds.at[:, :, 1 : (p.MAX_VACCINATION_COUNT) + 1, 0].add( + vax_gained[:, :, 0 : p.MAX_VACCINATION_COUNT] + ) + ds = ds - vax_counts + + # if we are not implementing seasonal vaccination p.SEASONAL_VACCINATION_RESET(t) = 0 forall t + # and you can safely ignore this section + seasonal_vaccination_outflow = p.SEASONAL_VACCINATION_RESET(t) + # flow seasonal_vaccination_outflow% of seasonal vaxers back to max ordinal tier + ds = ds.at[:, :, p.MAX_VACCINATION_COUNT - 1, :].add( + seasonal_vaccination_outflow * s[:, :, p.MAX_VACCINATION_COUNT, :] + ) + # remove these people from the seasonal vaccination tier + ds = ds.at[:, :, p.MAX_VACCINATION_COUNT, :].add( + -seasonal_vaccination_outflow * s[:, :, p.MAX_VACCINATION_COUNT, :] + ) + # do the same process for e and i compartments + de = de.at[:, :, p.MAX_VACCINATION_COUNT - 1, :].add( + seasonal_vaccination_outflow * e[:, :, p.MAX_VACCINATION_COUNT, :] + ) + de = de.at[:, :, p.MAX_VACCINATION_COUNT, :].add( + -seasonal_vaccination_outflow * e[:, :, p.MAX_VACCINATION_COUNT, :] + ) + di = di.at[:, :, p.MAX_VACCINATION_COUNT - 1, :].add( + seasonal_vaccination_outflow * i[:, :, p.MAX_VACCINATION_COUNT, :] + ) + di = di.at[:, :, p.MAX_VACCINATION_COUNT, :].add( + -seasonal_vaccination_outflow * i[:, :, p.MAX_VACCINATION_COUNT, :] + ) + + return (ds, de, di, dc) diff --git a/src/resp_ode/utils.py b/src/resp_ode/utils.py index 7d2900c3..3d9f0129 100644 --- a/src/resp_ode/utils.py +++ b/src/resp_ode/utils.py @@ -1367,6 +1367,10 @@ def get_foi_suscept(p, force_of_infection): Calculate the force of infections experienced by the susceptibles, _after_ factoring their immunity. + Calculates the minimum homologous immunity of individuals based on a + p.MIN_HOMOLOGOUS_IMMUNITY * the homologous immunity. Meaning if an individual has 70% + homologous immunity, their minimum floor after waning can be p.MIN_HOMOLOGOUS_IMMUNITY * 0.7 + Parameters ---------- `p` : Parameters @@ -1396,11 +1400,13 @@ def get_foi_suscept(p, force_of_infection): 1 - vax_efficacy_strain, ) # renormalize the waning curve to have minimum of `final_immunity` after full waning - # and maximum of `initial_immunity` right after recovery - final_immunity = jnp.zeros(shape=initial_immunity.shape) - final_immunity = final_immunity.at[ - all_immune_states_with(strain, p.NUM_STRAINS), : - ].set(p.MIN_HOMOLOGOUS_IMMUNITY) + # and maximum of `initial_immunity` right after rec + final_immunity = ( + jnp.ones(shape=initial_immunity.shape) + * crossimmunity_matrix[:, jnp.newaxis] + * p.MIN_HOMOLOGOUS_IMMUNITY + ) + waned_immunity_baseline = jnp.einsum( "jk,l", initial_immunity, From e12f4de531530ae8c684be9976cddb3e61f52b14 Mon Sep 17 00:00:00 2001 From: Ariel Shurygin Date: Fri, 27 Sep 2024 21:48:14 +0000 Subject: [PATCH 2/5] checkpoint projections now working and revamped --- Dockerfile | 4 +- .../experiment_launcher_azure.py | 108 ++++ .../inferer_projection.py | 42 +- exp/projections_ihr2_2404_2507/run_task.py | 8 +- .../template_configs/config_global.json | 57 ++ .../scenario_sampler0_template.json | 490 ++++++++++++++++++ .../template_configs/scenario_template.json | 489 +++++++++++++++++ 7 files changed, 1183 insertions(+), 15 deletions(-) create mode 100644 exp/projections_ihr2_2404_2507/experiment_launcher_azure.py create mode 100644 exp/projections_ihr2_2404_2507/template_configs/config_global.json create mode 100644 exp/projections_ihr2_2404_2507/template_configs/scenario_sampler0_template.json create mode 100644 exp/projections_ihr2_2404_2507/template_configs/scenario_template.json diff --git a/Dockerfile b/Dockerfile index 0d597d0c..2530d480 100644 --- a/Dockerfile +++ b/Dockerfile @@ -23,5 +23,5 @@ COPY ./src/ /app/src # turn off interaction since we cant type `yes` on the prompts in docker build RUN poetry install --no-interaction --no-ansi # we will upload the experiment itself into the cloud and refer to from /input -COPY ./mechanistic_azure/abstract_azure_runner.py /app/mechanistic_azure/abstract_azure_runner.py -COPY ./mechanistic_azure/azure_utilities.py /app/mechanistic_azure/azure_utilities.py +# COPY ./src/mechanistic_azure/abstract_azure_runner.py /app/src/mechanistic_azure/abstract_azure_runner.py +# COPY ./src/mechanistic_azure/azure_utilities.py /app/src/mechanistic_azure/azure_utilities.py diff --git a/exp/projections_ihr2_2404_2507/experiment_launcher_azure.py b/exp/projections_ihr2_2404_2507/experiment_launcher_azure.py new file mode 100644 index 00000000..a7dd053b --- /dev/null +++ b/exp/projections_ihr2_2404_2507/experiment_launcher_azure.py @@ -0,0 +1,108 @@ +""" +a script which is passed two flags, one to dictate the experiment folder on which to run +and another to specify the location of the state-specific runner script, which will perform analysis on a single state. +""" + +import argparse +import os + +from src.mechanistic_azure.azure_utilities import AzureExperimentLauncher + +# from cfa_azure.clients import AzureClient + + +class ScenarioLauncher(AzureExperimentLauncher): + """A class designed to launch jobs but specifically a number of scenarios + per state. + """ + + def launch_states(self, depend_on_task_ids: list[str] = None) -> list[str]: + """ + OVERIDDEN FUNCTION to launch scenarios for each state as well as the states themselves + Launches an Azure Batch job under `self.job_id`, + populating it with tasks for each subdirectory within your experiment's `states` directory + passing each state name to `run_task.py` with the -s flag and the job_id with the -j flag. + + Parameters + ---------- + depend_on_task_ids: list[str], optional + list of task ids on which each state depends on finishing to start themselves, defaults to None + Returns + ------- + list[str] + list of all tasks launched under `job_id` + """ + # command to run the job, if we have already launched states previously, dont recreate the job + if not self.job_launched: + self.azure_client.add_job(job_id=self.job_id) + self.job_launched = True + # upload the experiment folder so that the runner_path_docker & states_path_docker point to the correct places + # here we pass `location=self.experiment_path_blob` because we are not uploading from the docker container + # therefore we dont need the /input/ mount directory + self._upload_experiment_to_blob() + task_ids = [] + # add a task for each scenario json in state directory in the states folder of this experiment + for statedir in os.listdir(self.states_path_local): + statedir_path = os.path.join(self.states_path_local, statedir) + if os.path.isdir(statedir_path): + # add a task setting the runner onto each state + # we use the -s flag with the subdir name, + # since experiment directories are structured with USPS state codes as directory names + # also include the -j flag to specify the jobid + scenarios = [ + f.replace(".json", "") + for f in os.listdir(statedir_path) + if "config_global" not in f + ] + for sc in scenarios: + task_id = self.azure_client.add_task( + job_id=job_id, + docker_cmd="python %s -s %s -j %s -sc %s" + % (self.runner_path_docker, statedir, job_id, sc), + depends_on=depend_on_task_ids, + ) + task_ids += task_id + return task_ids + + +DOCKER_IMAGE_TAG = "arik-revamp-projections1" +# number of seconds of a full experiment run before timeout +# for `s` states to run and `n` nodes dedicated,`s/n` * runtime 1 state secs needed +TIMEOUT_MINS = 360 +EXPERIMENTS_DIRECTORY = "exp" +EXPERIMENT_NAME = "projections_ihr2_2404_2507" +SECRETS_PATH = "secrets/configuration_cfaazurebatchprd.toml" +# Parse command-line arguments +# specify job ID, cant already exist +parser = argparse.ArgumentParser(description="Experiment Azure Launcher") +parser.add_argument( + "--job_id", + type=str, + help="job ID of the azure job, must be unique", + required=True, +) + +args = parser.parse_args() +job_id: str = args.job_id +launcher = ScenarioLauncher( + EXPERIMENT_NAME, + job_id, + azure_config_toml=SECRETS_PATH, + experiment_directory=EXPERIMENTS_DIRECTORY, + docker_image_name=DOCKER_IMAGE_TAG, +) +launcher.set_resource_pool(pool_name="scenarios_4cpu_pool") +all_tasks_run = [] +# all experiments will be placed under the same jobid, +# subsequent experiments depend on prior ones to finish before starting +launcher.set_all_paths( + experiments_folder_name=EXPERIMENTS_DIRECTORY, + experiment_name=EXPERIMENT_NAME, +) +state_task_ids = launcher.launch_states(depend_on_task_ids=all_tasks_run) +print(state_task_ids) +postprocessing_tasks = launcher.launch_postprocess( + depend_on_task_ids=state_task_ids +) +all_tasks_run += state_task_ids + postprocessing_tasks +launcher.azure_client.monitor_job(job_id) diff --git a/exp/projections_ihr2_2404_2507/inferer_projection.py b/exp/projections_ihr2_2404_2507/inferer_projection.py index 360b7a17..4580a97b 100644 --- a/exp/projections_ihr2_2404_2507/inferer_projection.py +++ b/exp/projections_ihr2_2404_2507/inferer_projection.py @@ -11,12 +11,7 @@ from jax.scipy.stats.norm import pdf from numpyro.infer import MCMC -from resp_ode import ( - MechanisticInferer, - MechanisticRunner, - SEIC_Compartments, - utils, -) +from resp_ode import MechanisticInferer, MechanisticRunner, SEIC_Compartments from resp_ode.config import Config @@ -33,10 +28,12 @@ class ProjectionParameters(MechanisticInferer): "VACCINE_EFF_MATRIX", "BETA_TIMES", "STRAIN_R0s", - "INFECTIOUS_PERIOD" "EXPOSED_TO_INFECTIOUS", + "INFECTIOUS_PERIOD", + "EXPOSED_TO_INFECTIOUS", "INTRODUCTION_TIMES", "INTRODUCTION_SCALES", - "INTRODUCTION_PCTS" "INITIAL_INFECTIONS_SCALE", + "INTRODUCTION_PCTS", + "INITIAL_INFECTIONS_SCALE", "CONSTANT_STEP_SIZE", "SEASONALITY_AMPLITUDE", "SEASONALITY_SECOND_WAVE", @@ -574,7 +571,7 @@ def fake_sampler(): # re-create the CROSSIMMUNITY_MATRIX since we modified the STRAIN_INTERACTIONS matrix parameters[ "CROSSIMMUNITY_MATRIX" - ] = utils.strain_interaction_to_cross_immunity2( + ] = self.strain_interaction_to_cross_immunity( parameters["NUM_STRAINS"], parameters["STRAIN_INTERACTIONS"], ) @@ -587,6 +584,33 @@ def fake_sampler(): return parameters + def strain_interaction_to_cross_immunity( + self, num_strains: int, strain_interactions: np.ndarray + ) -> jax.Array: + """Because we are overriding the definitions of immune history, the model must + contain its own method for generating the cross immunity matrix, not relying on + utils.strain_interaction_to_cross_immunity() which has complex logic + to convert the strain interactions matrix to a cross immunity matrix + + Parameters + ---------- + num_strains : int + number of strains in the model + strain_interactions : jax.Array + the strain interactions matrix + + Returns + ------- + jax.Array + the cross immunity matrix which is similar to the strain + matrix but with an added 0 layer for no previous exposure + """ + cim = jnp.hstack( + (jnp.array([[0.0]] * num_strains), strain_interactions), + ) + + return cim + def get_parameters(self): """ Overriding the get_parameters() method to work with an undefined initial state diff --git a/exp/projections_ihr2_2404_2507/run_task.py b/exp/projections_ihr2_2404_2507/run_task.py index 9ba28df8..ede69eba 100644 --- a/exp/projections_ihr2_2404_2507/run_task.py +++ b/exp/projections_ihr2_2404_2507/run_task.py @@ -12,14 +12,14 @@ sys.path.append("/app/") sys.path.append("/input/exp/fifty_state_5strain_2202_2404/") print(os.getcwd()) +# sys.path.append(".") +# sys.path.append(os.getcwd()) +from inferer_projection import ProjectionParameters + from resp_ode import MechanisticRunner from resp_ode.model_odes.seip_model_flatten_immune_hist import seip_ode from src.mechanistic_azure.abstract_azure_runner import AbstractAzureRunner -# sys.path.append(".") -# sys.path.append(os.getcwd()) -from .inferer_projection import ProjectionParameters - jax.config.update("jax_enable_x64", True) # will be multiplied by number of chains to get total number of posteriors diff --git a/exp/projections_ihr2_2404_2507/template_configs/config_global.json b/exp/projections_ihr2_2404_2507/template_configs/config_global.json new file mode 100644 index 00000000..1de49d2e --- /dev/null +++ b/exp/projections_ihr2_2404_2507/template_configs/config_global.json @@ -0,0 +1,57 @@ +{ + "DEMOGRAPHIC_DATA_PATH": "/input/data/demographic-data/", + "REGIONS": [ + "Alabama" + ], + "INIT_DATE": "2024-04-20", + "VACCINATION_SEASON_CHANGE": "2025-07-01", + "AGE_LIMITS": [ + 0, + 18, + 50, + 65 + ], + "NUM_STRAINS": 7, + "MAX_VACCINATION_COUNT": 3, + "NUM_WANING_COMPARTMENTS": 4, + "WANING_TIMES": [ + 70, + 70, + 70, + 0 + ], + "COMPARTMENT_IDX": [ + "S", + "E", + "I", + "C" + ], + "S_AXIS_IDX": [ + "age", + "hist", + "vax", + "wane" + ], + "I_AXIS_IDX": [ + "age", + "hist", + "vax", + "strain" + ], + "C_AXIS_IDX": [ + "age", + "hist", + "vax", + "wane", + "strain" + ], + "STRAIN_IDX": [ + "omicron", + "BA2BA5", + "XBB1", + "XBB2", + "JN1", + "KP", + "X" + ] +} diff --git a/exp/projections_ihr2_2404_2507/template_configs/scenario_sampler0_template.json b/exp/projections_ihr2_2404_2507/template_configs/scenario_sampler0_template.json new file mode 100644 index 00000000..6f0cb526 --- /dev/null +++ b/exp/projections_ihr2_2404_2507/template_configs/scenario_sampler0_template.json @@ -0,0 +1,490 @@ +{ + "SCENARIO_NAME": "test covid run for testing suite", + "CONTACT_MATRIX_PATH": "/input/data/demographic-data/contact_matrices", + "SAVE_PATH": "/output/", + "HOSP_PATH": "/input/data/hospitalization-data", + "SERO_PATH": "/input/data/serological-data/fitting-2022/", + "VAR_PATH": "/input/data/variant-data/fitting-2022/", + "VACCINATION_MODEL_DATA": "/input/data/vaccination-data/2024_06_30_to_2025_06_28_vax88/", + "SEASONAL_VACCINATION": false, + "ZERO_VACCINE_DAY": 70, + "SAMPLE_STRAIN_X_INTRO_TIME": true, + "R0_MULTIPLIER": 0, + "STRAIN_R0s": [ + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 50, + "concentration0": 50 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 1.0, + "scale": 2.0, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 35, + "concentration0": 65 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 2.0, + "scale": 2.0, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 30, + "concentration0": 70 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 2.0, + "scale": 2.0, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 30, + "concentration0": 70 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 2.0, + "scale": 2.0, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + 1.0, + 1.0, + 1.0 + ], + "INFECTIOUS_PERIOD": 7.0, + "EXPOSED_TO_INFECTIOUS": 3.6, + "INITIAL_INFECTIONS_SCALE": { + "distribution": "TruncatedNormal", + "params": { + "loc": 1.0, + "scale": 0.1, + "low": 0.5, + "high": 2.0 + } + }, + "INTRODUCTION_TIMES": [ + 88, + 88 + ], + "CONSTANT_STEP_SIZE": 0, + "INTRODUCTION_PCTS": [ + 0.03, + 0.03 + ], + "INTRODUCTION_SCALES": [ + 13, + 13 + ], + "INTRODUCTION_AGE_MASK": [ + false, + true, + false, + false + ], + "MIN_HOMOLOGOUS_IMMUNITY": { + "distribution": "Beta", + "params": { + "concentration1": 500, + "concentration0": 1500 + } + }, + "WANING_PROTECTIONS": [ + 1.0, + 1.0, + 1.0, + 0.0 + ], + "STRAIN_INTERACTIONS": [ + [ + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 40, + "concentration0": 60 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0 + ], + [ + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 60, + "concentration0": 40 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0 + ], + [ + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 50, + "concentration0": 50 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 80, + "concentration0": 20 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0 + ], + [ + 0.25, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 50, + "concentration0": 50 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 80, + "concentration0": 20 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + 1.0, + 1.0, + 1.0, + 1.0 + ], + [ + 0.15, + 0.25, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 50, + "concentration0": 50 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 70, + "concentration0": 30 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + 1.0, + 1.0, + 1.0 + ], + [ + 0.05, + 0.15, + 0.25, + 0.70, + 0.95, + 1.0, + 1.0 + ], + [ + 0.0, + 0.05, + 0.15, + 0.25, + 0.888, + 0.888, + 1.0 + ] + ], + "VACCINE_EFF_MATRIX": [ + [ + 0, + 0.35, + 0.70, + 0.88 + ], + [ + 0, + 0.30, + 0.60, + 0.88 + ], + [ + 0, + 0.25, + 0.50, + 0.88 + ], + [ + 0, + 0.20, + 0.40, + 0.88 + ], + [ + 0, + 0.095, + 0.19, + 0.88 + ], + [ + 0, + 0.095, + 0.19, + 0.88 + ], + [ + 0, + 0.095, + 0.19, + 0.88 + ] + ], + "BETA_TIMES": [ + 0.0, 70.0 + ], + "BETA_COEFICIENTS": [ + 1.0, 1.0 + ], + "SEASONALITY_AMPLITUDE": { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 60, + "concentration0": 40 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.0, + "scale": 0.1, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + "SEASONALITY_SECOND_WAVE": { + "distribution": "Beta", + "params": { + "concentration1": 100, + "concentration0": 100 + } + }, + "SEASONALITY_SHIFT": { + "distribution": "TruncatedNormal", + "params": { + "loc": -10, + "scale": 1, + "low": -25, + "high": 5 + } + }, + "INFERENCE_PRNGKEY": 8675311, + "INFERENCE_NUM_WARMUP": 1000, + "INFERENCE_NUM_SAMPLES": 1000, + "INFERENCE_NUM_CHAINS": 4, + "INFERENCE_PROGRESS_BAR": true, + "MAX_TREE_DEPTH": 9, + "MODEL_RAND_SEED": 8675318 +} diff --git a/exp/projections_ihr2_2404_2507/template_configs/scenario_template.json b/exp/projections_ihr2_2404_2507/template_configs/scenario_template.json new file mode 100644 index 00000000..cb387327 --- /dev/null +++ b/exp/projections_ihr2_2404_2507/template_configs/scenario_template.json @@ -0,0 +1,489 @@ +{ + "SCENARIO_NAME": "test covid run for testing suite", + "CONTACT_MATRIX_PATH": "/input/data/demographic-data/contact_matrices", + "SAVE_PATH": "/output/", + "HOSP_PATH": "/input/data/hospitalization-data", + "SERO_PATH": "/input/data/serological-data/fitting-2022/", + "VAR_PATH": "/input/data/variant-data/fitting-2022/", + "VACCINATION_MODEL_DATA": "/input/data/vaccination-data/2024_06_30_to_2025_06_28_vax88/", + "SEASONAL_VACCINATION": false, + "ZERO_VACCINE_DAY": 70, + "SAMPLE_STRAIN_X_INTRO_TIME": true, + "STRAIN_R0s": [ + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 50, + "concentration0": 50 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 1.0, + "scale": 2.0, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 35, + "concentration0": 65 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 2.0, + "scale": 2.0, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 30, + "concentration0": 70 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 2.0, + "scale": 2.0, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 30, + "concentration0": 70 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 2.0, + "scale": 2.0, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + 1.0, + 1.0, + 1.0 + ], + "INFECTIOUS_PERIOD": 7.0, + "EXPOSED_TO_INFECTIOUS": 3.6, + "INITIAL_INFECTIONS_SCALE": { + "distribution": "TruncatedNormal", + "params": { + "loc": 1.0, + "scale": 0.1, + "low": 0.5, + "high": 2.0 + } + }, + "INTRODUCTION_TIMES": [ + 88, + 88 + ], + "CONSTANT_STEP_SIZE": 0, + "INTRODUCTION_PCTS": [ + 0.03, + 0.03 + ], + "INTRODUCTION_SCALES": [ + 13, + 13 + ], + "INTRODUCTION_AGE_MASK": [ + false, + true, + false, + false + ], + "MIN_HOMOLOGOUS_IMMUNITY": { + "distribution": "Beta", + "params": { + "concentration1": 500, + "concentration0": 1500 + } + }, + "WANING_PROTECTIONS": [ + 1.0, + 1.0, + 1.0, + 0.0 + ], + "STRAIN_INTERACTIONS": [ + [ + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 40, + "concentration0": 60 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0 + ], + [ + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 60, + "concentration0": 40 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0 + ], + [ + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 50, + "concentration0": 50 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 80, + "concentration0": 20 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + 1.0, + 1.0, + 1.0, + 1.0, + 1.0 + ], + [ + 0.25, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 50, + "concentration0": 50 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 80, + "concentration0": 20 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + 1.0, + 1.0, + 1.0, + 1.0 + ], + [ + 0.15, + 0.25, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 50, + "concentration0": 50 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 70, + "concentration0": 30 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + 1.0, + 1.0, + 1.0 + ], + [ + 0.05, + 0.15, + 0.25, + 0.70, + 0.95, + 1.0, + 1.0 + ], + [ + 0.0, + 0.05, + 0.15, + 0.25, + 0.888, + 0.888, + 1.0 + ] + ], + "VACCINE_EFF_MATRIX": [ + [ + 0, + 0.35, + 0.70, + 0.88 + ], + [ + 0, + 0.30, + 0.60, + 0.88 + ], + [ + 0, + 0.25, + 0.50, + 0.88 + ], + [ + 0, + 0.20, + 0.40, + 0.88 + ], + [ + 0, + 0.095, + 0.19, + 0.88 + ], + [ + 0, + 0.095, + 0.19, + 0.88 + ], + [ + 0, + 0.095, + 0.19, + 0.88 + ] + ], + "BETA_TIMES": [ + 0.0, 70.0 + ], + "BETA_COEFICIENTS": [ + 1.0, 1.0 + ], + "SEASONALITY_AMPLITUDE": { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 60, + "concentration0": 40 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.0, + "scale": 0.1, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + "SEASONALITY_SECOND_WAVE": { + "distribution": "Beta", + "params": { + "concentration1": 100, + "concentration0": 100 + } + }, + "SEASONALITY_SHIFT": { + "distribution": "TruncatedNormal", + "params": { + "loc": -10, + "scale": 1, + "low": -25, + "high": 5 + } + }, + "INFERENCE_PRNGKEY": 8675311, + "INFERENCE_NUM_WARMUP": 1000, + "INFERENCE_NUM_SAMPLES": 1000, + "INFERENCE_NUM_CHAINS": 4, + "INFERENCE_PROGRESS_BAR": true, + "MAX_TREE_DEPTH": 9, + "MODEL_RAND_SEED": 8675318 +} From 640ecae17864cd9e19a5298721589583fd3764ec Mon Sep 17 00:00:00 2001 From: Ariel Shurygin Date: Mon, 30 Sep 2024 21:34:19 +0000 Subject: [PATCH 3/5] checkpoint, fitting done, waiting to test via run when resources not needed --- .../fall_virus_inferer.py | 489 ++++++++++++++++++ .../postprocess_script_0.py | 290 +++++++++++ .../postprocess_script_1.py | 72 +++ .../run_task.py | 289 +++++++++++ .../template_configs/config_global.json | 55 ++ .../template_configs/config_inferer.json | 477 +++++++++++++++++ .../template_configs/config_initializer.json | 50 ++ .../template_configs/scenario_template.json | 6 +- 8 files changed, 1726 insertions(+), 2 deletions(-) create mode 100644 exp/fifty_state_season2_5strain_2202_2404/fall_virus_inferer.py create mode 100644 exp/fifty_state_season2_5strain_2202_2404/postprocess_script_0.py create mode 100644 exp/fifty_state_season2_5strain_2202_2404/postprocess_script_1.py create mode 100644 exp/fifty_state_season2_5strain_2202_2404/run_task.py create mode 100644 exp/fifty_state_season2_5strain_2202_2404/template_configs/config_global.json create mode 100644 exp/fifty_state_season2_5strain_2202_2404/template_configs/config_inferer.json create mode 100644 exp/fifty_state_season2_5strain_2202_2404/template_configs/config_initializer.json diff --git a/exp/fifty_state_season2_5strain_2202_2404/fall_virus_inferer.py b/exp/fifty_state_season2_5strain_2202_2404/fall_virus_inferer.py new file mode 100644 index 00000000..37fdfb8a --- /dev/null +++ b/exp/fifty_state_season2_5strain_2202_2404/fall_virus_inferer.py @@ -0,0 +1,489 @@ +import os + +import jax.numpy as jnp +import numpyro +import numpyro.distributions as Dist +from jax.random import PRNGKey +from jax.typing import ArrayLike + +from resp_ode import MechanisticInferer + + +class FallVirusInferer(MechanisticInferer): + UPSTREAM_PARAMETERS = [ + "INIT_DATE", + "CONTACT_MATRIX", + "POPULATION", + "NUM_STRAINS", + "NUM_AGE_GROUPS", + "NUM_WANING_COMPARTMENTS", + "WANING_PROTECTIONS", + "MAX_VACCINATION_COUNT", + "STRAIN_INTERACTIONS", + "VACCINE_EFF_MATRIX", + "BETA_TIMES", + "STRAIN_R0s", + "INFECTIOUS_PERIOD", + "EXPOSED_TO_INFECTIOUS", + "INTRODUCTION_TIMES", + "INTRODUCTION_SCALES", + "INTRODUCTION_PCTS", + "INITIAL_INFECTIONS_SCALE", + "CONSTANT_STEP_SIZE", + "SEASONALITY_AMPLITUDE", + "SEASONALITY_SECOND_WAVE", + "SEASONALITY_DOMINANT_WAVE_DAY", + "SEASONALITY_SECOND_WAVE_DAY", + "SEASONALITY_SHIFT", + "MIN_HOMOLOGOUS_IMMUNITY", + ] + + def load_vaccination_model(self): + """ + an overridden version of the vaccine model so we can load + state-specific vaccination splines using the REGIONS parameter + """ + vax_spline_filename = "spline_fits_%s.csv" % ( + self.config.REGIONS[0].lower().replace(" ", "_") + ) + vax_spline_path = os.path.join( + self.config.VACCINATION_MODEL_DATA, vax_spline_filename + ) + self.config.VACCINATION_MODEL_DATA = vax_spline_path + super().load_vaccination_model() + + def infer( + self, + obs_hosps, + obs_hosps_days, + obs_sero_lmean, + obs_sero_lsd, + obs_sero_days, + obs_var_prop, + obs_var_days, + obs_var_sd, + ): + """ + OVERRIDEN TO ADD MORE DATA STREAMS TO COMPARE AGAINST + Infer parameters given priors inside of self.config, returns an inference_algo object with posterior distributions for each sampled parameter. + + + Parameters + ---------- + obs_hosps: jnp.ndarray: weekly hosp incidence values from NHSN + obs_hosps_days: list[int] the sim day on which each obs_hosps value is measured. + for example obs_hosps[0] = 0 = self.config.INIT_DATE + obs_sero_lmean: jnp.ndarray: observed seroprevalence in logit scale + obs_sero_lsd: jnp.ndarray: standard deviation of logit seroprevalence (use this to + control the magnitude of uncertainty / weightage of fitting) + obs_sero_days: list[int] the sim day on which each obs_sero value is measured. + e.g., [9, 23, ...] meaning that we have data on day 9, 23, ... + + Returns + ----------- + an inference object, often numpyro.infer.MCMC object used to infer parameters. + This can be used to print summaries, pass along covariance matrices, or query posterier distributions + """ + self.inference_algo.run( + rng_key=PRNGKey(self.config.INFERENCE_PRNGKEY), + obs_hosps=obs_hosps, + obs_hosps_days=obs_hosps_days, + obs_sero_lmean=obs_sero_lmean, + obs_sero_lsd=obs_sero_lsd, + obs_sero_days=obs_sero_days, + obs_var_prop=obs_var_prop, + obs_var_days=obs_var_days, + obs_var_sd=obs_var_sd, + ) + self.inference_algo.print_summary() + self.infer_complete = True + self.inference_timesteps = max(obs_hosps_days) + 1 + return self.inference_algo + + def generate_downstream_parameters(self, parameters: dict) -> dict: + """An Override for the generate downstream parameters in order + to lock in transmisibility of certain strains to dependent on + the values of others + The following is the existing docstring for this function: + + takes an existing parameters object and attempts to generate a number of + downstream dependent parameters, based on the values contained within `parameters`. + + Raises RuntimeError if a downstream parameter + does not find the necessary values it needs within `parameters` + + Example + --------- + if the parameter `Y = 1/X` then X must be defined within `parameters` and + we call `parameters["Y"] = 1 / parameters["X"]` + + Parameters + ---------- + parameters : dict + parameters dictionary generated by `self._get_upstream_parameters()` + containing static or sampled values on which downstream parameters may depend + + Returns + ------- + dict + an appended onto version of `parameters` with additional downstream parameters added. + + """ + parameters = super().generate_downstream_parameters(parameters) + parameters["STRAIN_R0s"] = jnp.array( + [ + parameters["STRAIN_R0s"][0], + parameters["STRAIN_R0s"][1], + parameters["STRAIN_R0s"][2], + parameters["STRAIN_R0s"][3], + numpyro.deterministic( + "STRAIN_R0s_4", parameters["STRAIN_R0s"][2] + ), + ] + ) + parameters["BETA"] = ( + parameters["STRAIN_R0s"] / parameters["INFECTIOUS_PERIOD"] + ) + + return parameters + + def seasonality( + self, + t: ArrayLike, + seasonality_amplitude: ArrayLike, + seasonality_second_wave: ArrayLike, + dominant_wave_day: ArrayLike, + second_wave_day: ArrayLike, + ) -> ArrayLike: + """A revamped seasonality function which allows for a dominant and second wave to vary in phase from one another + As opposed to the original seasonality function which required both waves to be exactly 6 months apart + the user may now specify two dates for the peaks of the seasonality, a dominant wave amplitude, + and a coefficient on the second waves peak. + Parameters + ---------- + t : ArrayLike + simulation day with t=0 being INIT_DATE + seasonality_amplitude : ArrayLike + amplitude size, with peak at 1+seasonality_amplitude and minimum at 1-seasonality_amplitude + seasonality_second_wave : ArrayLike + enforced 0 <= seasonality_second_wave <= 1.0 + adjusts how pronouced the second wave is, + with 1.0 being equally sized to dominant wave, and 0 being no second wave + dominant_wave_day : ArrayLike + date on which dominant seasonality coefficient peak + second_wave_day : ArrayLike + date on which second seasonality coefficient peaks + """ + + def g(t: ArrayLike, phi: ArrayLike, w: int): + return jnp.cos(2 * jnp.pi * (t - phi) / 730.0) ** w + + def f( + t: ArrayLike, + dominant_wave_day, + second_wave_day, + seasonality_second_wave: ArrayLike, + w: int = 20, + ): + # w is a positive even integer used to enforce cos to be always positive + # its magnitude increases the slope of the cosine curve, reaching its + # peak faster and falling back to its minimum quicker + return ( + -0.5 + + 1.0 * g(t, phi=dominant_wave_day, w=w) + + seasonality_second_wave * g(t, phi=second_wave_day, w=w) + ) + + shifted_dominant_wave_day = ( + dominant_wave_day - self.config.INIT_DATE.timetuple().tm_yday + ) + shifted_second_wave_day = ( + second_wave_day - self.config.INIT_DATE.timetuple().tm_yday + ) + + return ( + 1 + + f( + t, + dominant_wave_day=shifted_dominant_wave_day, + second_wave_day=shifted_second_wave_day, + seasonality_second_wave=seasonality_second_wave, + ) + * seasonality_amplitude + * 2 + ) + + def _get_predictions(self, parameters, solution): + """ + OVERRIDEN FUNCTION of MechanisticInferer._get_predictions() + + This overriden function will calculate hospitalizations differently, but also + will calculate variant proportions and serology of the predicted population. + + The rest of this comment block remains unchanged from the original function: + generates post-hoc predictions from solved timeseries in `Solution` and + parameters used to generate them within `parameters`. This will often be hospitalizations + but could be more than just that. + + Parameters + ---------- + parameters : dict + parameters object returned by `get_parameters()` possibly containing information about the + infection hospitalization ratio + solution : Solution + Solution object returned by `_solve_runner` or any call to `self.runner.run()` + containing compartment timeseries + + Returns + ------- + jax.Array or tuple[jax.Array] + one or more jax arrays representing the different post-hoc predictions generated from + `solution`. If fitting upon hospitalizations only, then a single jax.Array representing hospitalizations will be present. + """ + # save the final timestep of solution array for each compartment + numpyro.deterministic( + "final_timestep_s", solution.ys[self.config.COMPARTMENT_IDX.S][-1] + ) + numpyro.deterministic( + "final_timestep_e", solution.ys[self.config.COMPARTMENT_IDX.E][-1] + ) + numpyro.deterministic( + "final_timestep_i", solution.ys[self.config.COMPARTMENT_IDX.I][-1] + ) + numpyro.deterministic( + "final_timestep_c", solution.ys[self.config.COMPARTMENT_IDX.C][-1] + ) + # sample intrinsic infection hospitalization rate here + # m_i is ratio btw the average across all states of the median of ihr_0 and the average across all states + # of the median of ihr_3 produced from a previous fit + # v_i is defined similarly, but as the variance + # sample intrinsic infection hospitalization rate, + # where the concentrations are based on the function fit_new_beta + v_0, v_1, v_2 = ( + 9.361000583593154e-05, + 6.948228259019488e-05, + 0.0002923281265483212, + ) + + m_0, m_1, m_2 = ( + 0.020448716487218747, + 0.048698216511437936, + 0.1402618274806952, + ) + + ihr_mult_0 = numpyro.sample( + "ihr_mult_0", + Dist.Beta( + (m_0 * (1 - m_0) / v_0 - 1) * m_0, + (m_0 * (1 - m_0) / v_0 - 1) * (1 - m_0), + ), + ) + ihr_mult_1 = numpyro.sample( + "ihr_mult_1", + Dist.Beta( + (m_1 * (1 - m_1) / v_1 - 1) * m_1, + (m_1 * (1 - m_1) / v_1 - 1) * (1 - m_1), + ), + ) + ihr_mult_2 = numpyro.sample( + "ihr_mult_2", + Dist.Beta( + (m_2 * (1 - m_2) / v_2 - 1) * m_2, + (m_2 * (1 - m_2) / v_2 - 1) * (1 - m_2), + ), + ) + ihr_3 = numpyro.sample("ihr_3", Dist.Beta(60 * 20, 340 * 20)) + ihr = jnp.array( + [ihr_3 * ihr_mult_0, ihr_3 * ihr_mult_1, ihr_3 * ihr_mult_2, ihr_3] + ) + + # sample ihr multiplier due to previous infection or vaccinations + ihr_immune_mult = numpyro.sample( + "ihr_immune_mult", Dist.Beta(100 * 6, 300 * 6) + ) + + # sample ihr multiplier due to JN1 (assuming JN1 has less severity) + # ihr_jn1_mult = numpyro.sample("ihr_jn1_mult", Dist.Beta(100, 1)) + ihr_jn1_mult = numpyro.sample( + "ihr_jn1_mult", Dist.Beta(380 * 3, 20 * 3) + ) + + # calculate modelled hospitalizations based on the ihrs + # add 1 to wane because we have time dimension prepended + model_incidence = jnp.diff( + solution.ys[self.config.COMPARTMENT_IDX.C], + axis=0, + ) + model_incidence = jnp.sum(model_incidence, axis=4) + + model_incidence_no_exposures_non_jn1 = jnp.sum( + model_incidence[:, :, 0, 0, :4], axis=-1 + ) + model_incidence_no_exposures_jn1 = model_incidence[:, :, 0, 0, 4] + model_incidence_all_non_jn1 = jnp.sum( + model_incidence[:, :, :, :, :4], axis=(2, 3, 4) + ) + model_incidence_all_jn1 = jnp.sum( + model_incidence[:, :, :, :, 4], axis=(2, 3) + ) + model_incidence_w_exposures_non_jn1 = ( + model_incidence_all_non_jn1 - model_incidence_no_exposures_non_jn1 + ) + model_incidence_w_exposures_jn1 = ( + model_incidence_all_jn1 - model_incidence_no_exposures_jn1 + ) + + # calculate weekly model hospitalizations with the two IHRs we created + # TODO, should we average every 7 days or just pick every day from obs_metrics + model_hosps = ( + model_incidence_no_exposures_non_jn1 * ihr + + model_incidence_no_exposures_jn1 * ihr * ihr_jn1_mult + + model_incidence_w_exposures_non_jn1 * ihr * ihr_immune_mult + + model_incidence_w_exposures_jn1 + * ihr + * ihr_immune_mult + * ihr_jn1_mult + ) + ## Seroprevalence + never_infected = jnp.sum( + solution.ys[self.config.COMPARTMENT_IDX.S][:, :, 0, :, :], + axis=(2, 3), + ) + sim_seroprevalence = 1 - never_infected / self.config.POPULATION + # Strain prev + strain_incidence = jnp.sum( + solution.ys[self.config.COMPARTMENT_IDX.C], + axis=( + self.config.C_AXIS_IDX.age + 1, + self.config.C_AXIS_IDX.hist + 1, + self.config.C_AXIS_IDX.vax + 1, + self.config.C_AXIS_IDX.wane + 1, + ), + ) + return (model_hosps, sim_seroprevalence, strain_incidence) + + def run_simulation(self, tf): + """An override of the mechanistic_inferer.run_simulation() function in order to + save all needed timelines + + + Parameters + ---------- + tf : _type_ + _description_ + """ + parameters = self.get_parameters() + solution = self._solve_runner(parameters, tf, self.runner) + ( + hospitalizations, + sim_seroprevalence, + strain_incidence, + ) = self._get_predictions(parameters, solution) + return { + "solution": solution, + "hospitalizations": hospitalizations, + "sim_seroprevalence": sim_seroprevalence, + "strain_incidence": strain_incidence, + "parameters": parameters, + } + + def likelihood( + self, + obs_hosps=None, + obs_hosps_days=None, + obs_sero_lmean=None, + obs_sero_lsd=None, + obs_sero_days=None, + obs_var_prop=None, + obs_var_days=None, + obs_var_sd=None, + tf=None, + infer_mode=True, + ): + """ + overridden likelihood that takes as input weekly hosp data starting from self.config.INIT_DATE + + Parameters + ---------- + obs_hosps: jnp.ndarray: weekly hosp incidence values from NHSN + obs_hosps_days: list[int] the sim day on which each obs_hosps value is measured. + for example obs_hosps[0] = 0 = self.config.INIT_DATE + obs_sero_lmean: jnp.ndarray: observed seroprevalence in logit scale + obs_sero_lsd: jnp.ndarray: standard deviation of logit seroprevalence (use this to + control the magnitude of uncertainty / weightage of fitting) + obs_sero_days: list[int] the sim day on which each obs_sero value is measured. + e.g., [9, 23, ...] meaning that we have data on day 9, 23, ... + """ + dct = self.run_simulation(tf) + model_hosps = dct["hospitalizations"] + # add 1 to idxs because we are stratified by time in the solution object + # sum down to just time x age bins + # obs_hosps_days = [6, 13, 20, ....] + # Incidence from day 0, 1, 2, ..., 6 goes to first bin, day 7 - 13 goes to second bin... + # break model_hosps into chunks of intervals and aggregate them + # first, find out which interval goes to which days + hosps_interval_ind = jnp.searchsorted( + jnp.array(obs_hosps_days), jnp.arange(max(obs_hosps_days) + 1) + ) + # for observed, multiply number by number of days within an interval + obs_hosps_interval = ( + obs_hosps + * jnp.bincount(hosps_interval_ind, length=len(obs_hosps_days))[ + :, None + ] + ) + # for simulated, aggregate by index + sim_hosps_interval = jnp.array( + [ + jnp.bincount(hosps_interval_ind, m, length=len(obs_hosps_days)) + for m in model_hosps.T + ] + ).T + # x.shape = [650, 4] + # for x[0:7, :] -> y[0, :] + # y.shape = [65, 4] + mask_incidence = ~jnp.isnan(obs_hosps_interval) + with numpyro.handlers.mask(mask=mask_incidence): + numpyro.sample( + "incidence", + Dist.Poisson(sim_hosps_interval), + obs=obs_hosps_interval, + ) + sim_seroprevalence = dct["sim_seroprevalence"] + sim_seroprevalence = sim_seroprevalence[obs_sero_days, ...] + sim_lseroprevalence = jnp.log( + sim_seroprevalence / (1 - sim_seroprevalence) + ) # logit seroprevalence + + mask_sero = ~jnp.isnan(obs_sero_lmean) + with numpyro.handlers.mask(mask=mask_sero): + numpyro.sample( + "lseroprevalence", + Dist.Normal(sim_lseroprevalence, obs_sero_lsd), + obs=obs_sero_lmean, + ) + + ## Variant proportion + strain_incidence = dct["strain_incidence"] + strain_incidence = jnp.diff(strain_incidence, axis=0)[ + : (max(obs_var_days) + 1) + ] + var_interval_ind = jnp.searchsorted( + jnp.array(obs_var_days), jnp.arange(max(obs_var_days) + 1) + ) + strain_incidence_interval = jnp.array( + [ + jnp.bincount(var_interval_ind, m, length=len(obs_var_days)) + for m in strain_incidence.T + ] + ).T + sim_var_prop = jnp.array( + [incd / jnp.sum(incd) for incd in strain_incidence_interval] + ) + sim_var_sd = jnp.ones(sim_var_prop.shape) * obs_var_sd + + numpyro.sample( + "variant_proportion", + Dist.Normal(sim_var_prop, sim_var_sd), + obs=obs_var_prop, + ) diff --git a/exp/fifty_state_season2_5strain_2202_2404/postprocess_script_0.py b/exp/fifty_state_season2_5strain_2202_2404/postprocess_script_0.py new file mode 100644 index 00000000..85c6bd78 --- /dev/null +++ b/exp/fifty_state_season2_5strain_2202_2404/postprocess_script_0.py @@ -0,0 +1,290 @@ +# %% +import argparse +import json +import multiprocessing as mp +import os + +import jax.numpy as jnp +import matplotlib.dates as mdates +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from cycler import cycler +from fall_virus_inferer import FallVirusInferer +from matplotlib.backends.backend_pdf import PdfPages +from run_task import rework_initial_state + +from resp_ode import CovidSeroInitializer, MechanisticRunner +from resp_ode.model_odes.seip_model_flatten_immune_hist import seip_ode + +plt.switch_backend("agg") + + +def retrieve_inferer_obs(state, az_output_path, model_day): + state_config_path = os.path.join(az_output_path, state) + print("Retrieving " + state + "\n") + GLOBAL_CONFIG_PATH = os.path.join( + state_config_path, "config_global_used.json" + ) + TEMP_GLOBAL_CONFIG_PATH = os.path.join( + state_config_path, "temp_config_global_template.json" + ) + global_js = json.load(open(GLOBAL_CONFIG_PATH)) + global_js["NUM_STRAINS"] = 3 + global_js["NUM_WANING_COMPARTMENTS"] = 5 + global_js["WANING_TIMES"] = [70, 70, 70, 129, 0] + json.dump(global_js, open(TEMP_GLOBAL_CONFIG_PATH, "w")) + INITIALIZER_CONFIG_PATH = os.path.join( + state_config_path, "config_initializer_used.json" + ) + INFERER_CONFIG_PATH = os.path.join( + state_config_path, "config_inferer_used.json" + ) + + # sets up the initial conditions, initializer.get_initial_state() passed to runner + initializer = CovidSeroInitializer( + INITIALIZER_CONFIG_PATH, TEMP_GLOBAL_CONFIG_PATH + ) + runner = MechanisticRunner(seip_ode) + initial_state = initializer.get_initial_state() + initial_state = rework_initial_state(initial_state) + inferer = FallVirusInferer( + GLOBAL_CONFIG_PATH, INFERER_CONFIG_PATH, runner, initial_state + ) + # observed data + hosp_data_filename = "%s_hospitalization.csv" % ( + initializer.config.REGIONS[0].replace(" ", "_") + ) + hosp_data_path = os.path.join(inferer.config.HOSP_PATH, hosp_data_filename) + hosp_data = pd.read_csv(hosp_data_path) + hosp_data["date"] = pd.to_datetime(hosp_data["date"]) + # align hosp to infections assuming 7-day inf -> hosp delay + hosp_data["day"] = ( + hosp_data["date"] - pd.to_datetime(inferer.config.INIT_DATE) + ).dt.days - 7 + # only keep hosp data that aligns to our initial date + # sort ascending + hosp_data = hosp_data.loc[ + (hosp_data["day"] >= 0) & (hosp_data["day"] <= model_day) + ].sort_values(by=["day", "agegroup"], ascending=True, inplace=False) + # make hosp into day x agegroup matrix + obs_hosps = hosp_data.groupby(["day"])["hosp"].apply(np.array) + obs_hosps_days = obs_hosps.index.to_list() + obs_hosps = jnp.array(obs_hosps.to_list()) + + sero_data_filename = "%s_sero.csv" % ( + initializer.config.REGIONS[0].replace(" ", "_") + ) + sero_data_path = os.path.join(inferer.config.SERO_PATH, sero_data_filename) + sero_data = pd.read_csv(sero_data_path) + sero_data["date"] = pd.to_datetime(sero_data["date"]) + # align sero to infections assuming 14-day seroconversion delay + sero_data["day"] = ( + sero_data["date"] - pd.to_datetime(inferer.config.INIT_DATE) + ).dt.days - 14 + sero_data = sero_data.loc[ + (sero_data["day"] >= 0) & (sero_data["day"] <= model_day) + ].sort_values(by=["day", "age"], ascending=True, inplace=False) + # transform data to logit scale + sero_data["logit_rate"] = np.log( + sero_data["rate"] / (100.0 - sero_data["rate"]) + ) + # make sero into day x agegroup matrix + obs_sero_lmean = sero_data.groupby(["day"])["logit_rate"].apply(np.array) + obs_sero_days = obs_sero_lmean.index.to_list() + obs_sero_lmean = jnp.array(obs_sero_lmean.to_list()) + + var_data_filename = "%s_strain_prop.csv" % ( + initializer.config.REGIONS[0].replace(" ", "_") + ) + var_data_path = os.path.join(inferer.config.VAR_PATH, var_data_filename) + # currently working up to third strain which is XBB1 + var_data = pd.read_csv(var_data_path) + var_data = var_data[var_data["strain"] < 6] + var_data["date"] = pd.to_datetime(var_data["date"]) + var_data["day"] = ( + var_data["date"] - pd.to_datetime("2022-02-11") + ).dt.days # no shift in alignment for variants + var_data = var_data.loc[ + (var_data["day"] >= 0) & (var_data["day"] <= model_day) + ].sort_values(by=["day", "strain"], ascending=True, inplace=False) + obs_var_prop = var_data.groupby(["day"])["share"].apply(np.array) + obs_var_days = obs_var_prop.index.to_list() + obs_var_prop = jnp.array(obs_var_prop.to_list()) + obs_var_prop = obs_var_prop / jnp.sum(obs_var_prop, axis=1)[:, None] + + return ( + inferer, + runner, + obs_hosps, + obs_hosps_days, + obs_sero_lmean, + obs_sero_days, + obs_var_prop, + obs_var_days, + ) + + +def retrieve_fitted_medians(state, az_output_path): + json_file = os.path.join(az_output_path, state, "checkpoint.json") + post_samp = json.load(open(json_file, "r")) + del post_samp["final_timestep_s"] + del post_samp["final_timestep_e"] + del post_samp["final_timestep_i"] + del post_samp["final_timestep_c"] + fitted_medians = { + k: jnp.median(jnp.array(v), axis=(0, 1)) for k, v in post_samp.items() + } + + return fitted_medians + + +def retrieve_timeline(state, az_output_path): + csv_file = os.path.join( + az_output_path, state, "azure_visualizer_timeline.csv" + ) + timeline = pd.read_csv(csv_file) + timeline["date"] = pd.to_datetime(timeline["date"]) + + return timeline + + +def process_plot_state(state, az_output_path): + timeline = retrieve_timeline(state, az_output_path) + fitted_medians = retrieve_fitted_medians(state, az_output_path) + fitted_medians["state"] = state + median_df = pd.DataFrame(fitted_medians, index=[state]) + ( + inferer, + runner, + obs_hosps, + obs_hosps_days, + obs_sero_lmean, + obs_sero_days, + obs_var_prop, + obs_var_days, + ) = retrieve_inferer_obs(state, az_output_path) + + obs_sero = 1 / (1 + np.exp(-obs_sero_lmean)) + chain_particles = timeline["chain_particle"].unique() + date_format = mdates.DateFormatter("%b\n%y") + colors_age = ["#e41a1c", "#377eb8", "#4daf4a", "#984ea3"] + colors_strain = [ + "#1b9e77", + "#d95f02", + "#7570b3", + "#e7298a", + "#66a61e", + "#e6ab02", + ] + fig, axs = plt.subplots(3, 1) + # Configs + axs[0].xaxis.set_major_formatter(date_format) + axs[0].set_prop_cycle(cycler(color=colors_age)) + axs[0].set_title(state + ": Observed vs fitted") + axs[0].set_ylim([0.01, 500]) + axs[0].set_yscale("log") + axs[0].set_ylabel("Hospitalization") + + axs[1].xaxis.set_major_formatter(date_format) + axs[1].set_prop_cycle(cycler(color=colors_age)) + axs[1].set_ylabel("Seroprevalence") + axs[1].set_ylim([0, 1.1]) + + axs[2].set_prop_cycle(cycler(color=colors_strain[:6])) + axs[2].xaxis.set_major_formatter(date_format) + axs[2].set_ylabel("Variant proportion") + # Simulated + for i, cp in enumerate(chain_particles): + sub_df = timeline[timeline["chain_particle"] == cp] + sim_hosp = np.array( + sub_df[ + [ + "pred_hosp_0_17", + "pred_hosp_18_49", + "pred_hosp_50_64", + "pred_hosp_65+", + ] + ] + ) + sim_sero = np.array( + sub_df[["sero_0_17", "sero_18_49", "sero_50_64", "sero_65+"]] + ) + sp_columns = [x for x in sub_df.columns if "strain_proportion" in x] + sim_var_prop = np.array(sub_df[sp_columns]) + if i == 0: + dates = np.array(sub_df["date"]) + axs[0].plot( + dates, + sim_hosp, + label=["0-17", "18-49", "50-64", "65+"], + alpha=0.1, + ) + axs[2].plot( + dates, + sim_var_prop, + alpha=0.1, + label=inferer.config.STRAIN_IDX, + ) + else: + axs[0].plot(dates, sim_hosp, alpha=0.1) + axs[2].plot( + dates, + sim_var_prop, + alpha=0.1, + ) + + axs[1].plot(dates, sim_sero, alpha=0.1) + + # Observed + axs[0].plot(dates[obs_hosps_days], obs_hosps, linestyle=":") + for s, c in zip(jnp.transpose(obs_sero), colors_age): + axs[1].scatter(dates[obs_sero_days], s, color=c) + for v, c in zip(jnp.transpose(obs_var_prop), colors_strain[:6]): + axs[2].scatter(dates[obs_var_days], v, color=c, s=7) + + fig.set_size_inches(8, 10) + fig.set_dpi(300) + leg = fig.legend(loc=7) + for lh in leg.legend_handles: + lh.set_alpha(1) + return fig, median_df + + +def save_states_pdf(output_path, job_id): + job_path = os.path.join(output_path, job_id) + states = [ + d + for d in os.listdir(job_path) + if os.path.isdir(os.path.join(job_path, d)) + ] + pool = mp.Pool(5) + figs, median_dfs = zip( + *pool.map(process_plot_state, [(st, job_path) for st in states]) + ) + + pdf_pages = PdfPages( + os.path.join(job_path, "obs_vs_fitted_%s.pdf" % job_id) + ) + for f in figs: + pdf_pages.savefig(f) + plt.close(f) + pdf_pages.close() + + pool.close() + + +parser = argparse.ArgumentParser(description="Experiment Azure Launcher") +parser.add_argument( + "--job_id", + "-j", + type=str, + help="job ID of the azure job, must be unique", + required=True, +) + +if __name__ == "__main__": + args = parser.parse_args() + job_id: str = args.job_id + output_path = "/output/fifty_state_season2_5strain_2202_2404" + save_states_pdf(output_path, job_id) diff --git a/exp/fifty_state_season2_5strain_2202_2404/postprocess_script_1.py b/exp/fifty_state_season2_5strain_2202_2404/postprocess_script_1.py new file mode 100644 index 00000000..0478ef9d --- /dev/null +++ b/exp/fifty_state_season2_5strain_2202_2404/postprocess_script_1.py @@ -0,0 +1,72 @@ +import argparse +import csv +import json +import os +import warnings + +from tqdm import tqdm + +OUTPUT_PATH = "/output/fifty_state_season2_5strain_2202_2404/" + + +def collate_checkpoints_to_csv(output_path, jobid): + """Grabs all states under output_path/jobid reads in their checkpoint.json, and saves + the values as a csv + + Parameters + ---------- + output_path : str + output path as str where job is stored + jobid : str + jobid of the job run + """ + job_path = os.path.join(output_path, jobid) + output_csv = os.path.join(job_path, "checkpoints_collated.csv") + scenarios = os.listdir(job_path) + # all the scenarios have the same checkpoints, so lets just go into one of them + scenario = scenarios[0] + states = os.listdir(job_path) + with open(output_csv, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + + # Write the header row to the CSV file + writer.writerow( + ["state", "parameter_name", "chain", "sample_num", "value"] + ) + for st in tqdm(states, desc="processing states "): + print(st) + # only looking at one scenario per state + state_path = os.path.join(job_path, st, scenario) + checkpoint_path = os.path.join(state_path, "checkpoint.json") + if os.path.exists(checkpoint_path): + checkpoint = json.load(open(checkpoint_path, "r")) + # Iterate over each parameter in the JSON data + for parameter_name, values in checkpoint.items(): + # Iterate over each chain and sample number in the 2D list of values + for chain, row in enumerate(values): + for sample_num, value in enumerate(row): + # Write a row to the CSV file with state, parameter_name, chain, sample_num, and value + writer.writerow( + [st, parameter_name, chain, sample_num, value] + ) + + else: + warnings.warn( + "%s state path lacks a checkpoint.json file, check your paths or job for this state did not complete" + % state_path + ) + + +parser = argparse.ArgumentParser(description="Experiment Azure Launcher") +parser.add_argument( + "--job_id", + "-j", + type=str, + help="job ID of the azure job, must be unique", + required=True, +) + +if __name__ == "__main__": + args = parser.parse_args() + job_id: str = args.job_id + collate_checkpoints_to_csv(OUTPUT_PATH, job_id) diff --git a/exp/fifty_state_season2_5strain_2202_2404/run_task.py b/exp/fifty_state_season2_5strain_2202_2404/run_task.py new file mode 100644 index 00000000..0bdba469 --- /dev/null +++ b/exp/fifty_state_season2_5strain_2202_2404/run_task.py @@ -0,0 +1,289 @@ +# ruff: noqa: E402 +import argparse +import json +import os +import shutil +import sys + +import jax +import numpy as np + +# adding things to path since in a docker container pathing gets changed +sys.path.append("/app/") +sys.path.append("/input/exp/fifty_state_season2_5strain_2202_2404/") +print(os.getcwd()) +# sys.path.append(".") +# sys.path.append(os.getcwd()) +import jax.numpy as jnp +import pandas as pd +from fall_virus_inferer import FallVirusInferer + +from resp_ode import CovidSeroInitializer, MechanisticRunner +from resp_ode.model_odes.seip_model_flatten_immune_hist import seip_ode +from resp_ode.utils import combine_strains, combined_strains_mapping +from src.mechanistic_azure.abstract_azure_runner import AbstractAzureRunner + +jax.config.update("jax_enable_x64", True) + + +def rework_initial_state(initial_state): + """ + Take the original `initial_state` which is (4, 16, 3, 5), collapsing the + strain 0 and 1 into 0, and collapsing wane 3 and 4 into 3, resulting in + a 3-strain initial_state with (4, 8, 3, 4). + """ + hist_map, strain_map = combined_strains_mapping(1, 0, 3) + s_new_1 = combine_strains( + initial_state[0], hist_map, strain_map, 3, strain_axis=False + )[:, :, :, :] + s_new_2 = jnp.ones((4, 8, 3, 4)) * s_new_1[:, :, :, :4] + s_new = s_new_2.at[:, :, :, 3].add(s_new_1[:, :, :, 4]) + e_new = combine_strains( + initial_state[1], hist_map, strain_map, 3, strain_axis=True + )[:, :, :, :] + i_new = combine_strains( + initial_state[2], hist_map, strain_map, 3, strain_axis=True + )[:, :, :, :] + c_new = initial_state[3][:, :, :, :] + concatenate_shp = list(e_new.shape) + concatenate_shp[3] = 2 + e_new = jnp.concatenate((e_new, jnp.zeros(tuple(concatenate_shp))), axis=3) + i_new = jnp.concatenate((i_new, jnp.zeros(tuple(concatenate_shp))), axis=3) + c_new_shape = list(s_new.shape) + c_new_shape.append(i_new.shape[3]) + c_new = jnp.zeros(tuple(c_new_shape)) + # c_new = jnp.concatenate((c_new, jnp.zeros(tuple(concatenate_shp))), axis=3) + initial_state = ( + s_new[:, 0:6, ...], + e_new[:, 0:6, ...], + i_new[:, 0:6, ...], + c_new[:, 0:6, ...], + ) + return initial_state + + +def preprocess_observed_data(initializer, inferer, model_day): + """A function responsible for reading in observed data, preprocessing it, and returning + a tuple of timeseries + + Parameters + ---------- + initializer : CovidSeroInitializer + initializer used to construct initial states + inferer : MechanisticInferer + inferer used to calculate likelihood on this observed data + model_day : int + Number of days on which the model will run, and the observed data should span + + Returns + ------- + tuple[Jax.Array] + tuple of observed datasets, specifically hospitalization, serology, and variant proportion. + """ + hosp_data_filename = "%s_hospitalization.csv" % ( + initializer.config.REGIONS[0].replace(" ", "_") + ) + hosp_data_path = os.path.join(inferer.config.HOSP_PATH, hosp_data_filename) + hosp_data = pd.read_csv(hosp_data_path) + hosp_data["date"] = pd.to_datetime(hosp_data["date"]) + # special setting for HI + if state == "HI": + hosp_data.loc[ + (hosp_data["date"] > pd.to_datetime("2022-08-01")) + & (hosp_data["agegroup"] == "0-17"), + "hosp", + ] = np.nan + # align hosp to infections assuming 7-day inf -> hosp delay + hosp_data["day"] = ( + hosp_data["date"] - pd.to_datetime(inferer.config.INIT_DATE) + ).dt.days - 7 + # only keep hosp data that aligns to our initial date + # sort ascending + hosp_data = hosp_data.loc[ + (hosp_data["day"] >= 0) & (hosp_data["day"] <= model_day) + ].sort_values(by=["day", "agegroup"], ascending=True, inplace=False) + # make hosp into day x agegroup matrix + obs_hosps = hosp_data.groupby(["day"])["hosp"].apply(np.array) + obs_hosps_days = obs_hosps.index.to_list() + obs_hosps = jnp.array(obs_hosps.to_list()) + + sero_data_filename = "%s_sero.csv" % ( + initializer.config.REGIONS[0].replace(" ", "_") + ) + sero_data_path = os.path.join(inferer.config.SERO_PATH, sero_data_filename) + sero_data = pd.read_csv(sero_data_path) + sero_data["date"] = pd.to_datetime(sero_data["date"]) + # align sero to infections assuming 14-day seroconversion delay + sero_data["day"] = ( + sero_data["date"] - pd.to_datetime(inferer.config.INIT_DATE) + ).dt.days - 14 + sero_data = sero_data.loc[ + (sero_data["day"] >= 0) & (sero_data["day"] <= model_day) + ].sort_values(by=["day", "age"], ascending=True, inplace=False) + # transform data to logit scale + sero_data["logit_rate"] = np.log( + sero_data["rate"] / (100.0 - sero_data["rate"]) + ) + # make sero into day x agegroup matrix + obs_sero_lmean = sero_data.groupby(["day"])["logit_rate"].apply(np.array) + obs_sero_days = obs_sero_lmean.index.to_list() + obs_sero_lmean = jnp.array(obs_sero_lmean.to_list()) + obs_sero_lmean = obs_sero_lmean.at[np.isinf(obs_sero_lmean)].set(jnp.nan) + # set sero sd, currently this is an arbitrary tunable parameters + # dependent on sero sample size. + obs_sero_n = sero_data.groupby(["day"])["n"].apply(np.array) + obs_sero_lsd = 1.0 / jnp.sqrt(jnp.array(obs_sero_n.to_list())) + obs_sero_lsd = obs_sero_lsd.at[jnp.isnan(obs_sero_lsd)].set(0.5) + + var_data_filename = "%s_strain_prop.csv" % ( + initializer.config.REGIONS[0].replace(" ", "_") + ) + var_data_path = os.path.join(inferer.config.VAR_PATH, var_data_filename) + # currently working up to third strain which is XBB1 + var_data = pd.read_csv(var_data_path) + var_data = var_data[var_data["strain"] < 5] + var_data["date"] = pd.to_datetime(var_data["date"]) + var_data["day"] = ( + var_data["date"] - pd.to_datetime("2022-02-11") + ).dt.days # no shift in alignment for variants + var_data = var_data.loc[ + (var_data["day"] >= 0) & (var_data["day"] <= model_day) + ].sort_values(by=["day", "strain"], ascending=True, inplace=False) + obs_var_prop = var_data.groupby(["day"])["share"].apply(np.array) + obs_var_days = obs_var_prop.index.to_list() + obs_var_prop = jnp.array(obs_var_prop.to_list()) + # renormalizing the var prop + obs_var_prop = obs_var_prop / jnp.sum(obs_var_prop, axis=1)[:, None] + obs_var_sd = 80 / jnp.sqrt(jnp.sum(inferer.config.POPULATION)) + return ( + obs_hosps, + obs_hosps_days, + obs_sero_lmean, + obs_sero_lsd, + obs_sero_days, + obs_var_prop, + obs_var_days, + obs_var_sd, + ) + + +class EpochOneRunner(AbstractAzureRunner): + # __init__ already implemented by the abstract case + def __init__(self, azure_output_dir): + super().__init__(azure_output_dir) + + def process_state(self, state, jobid=None, jobid_in_path=False): + model_day = 800 + # step 1: define your paths, now in the input + state_config_path = os.path.join( + "/input/exp/fifty_state_season2_5strain_2202_2404/states", + state, + ) + if jobid_in_path: + state_config_path = os.path.join( + "/input/exp/fifty_state_season2_5strain_2202_2404", + jobid, + "states", + state, + ) + # state_config_path = "exp/fifty_state_sero_second_try/" + args.state + "/" + print("Running the following state: " + state + "\n") + # global_config include definitions such as age bin bounds and strain definitions + # Any value or data structure that needs context to be interpretted is here. + GLOBAL_CONFIG_PATH = os.path.join( + state_config_path, "config_global.json" + ) + # a temporary global config that matches with original initializer + TEMP_GLOBAL_CONFIG_PATH = os.path.join( + state_config_path, "temp_config_global.json" + ) + global_js = json.load(open(GLOBAL_CONFIG_PATH)) + global_js["NUM_STRAINS"] = 3 + global_js["NUM_WANING_COMPARTMENTS"] = 5 + global_js["WANING_TIMES"] = [70, 70, 70, 129, 0] + json.dump(global_js, open(TEMP_GLOBAL_CONFIG_PATH, "w")) + + # defines the init conditions of the scenario: pop size, initial infections etc. + INITIALIZER_CONFIG_PATH = os.path.join( + state_config_path, "config_initializer.json" + ) + # defines prior __distributions__ for inferring runner variables. + INFERER_CONFIG_PATH = os.path.join( + state_config_path, "config_inferer.json" + ) + # save copies of the used config files to output for reproducibility purposes + cg_path = self.azure_output_dir + "config_global_used.json" + cinf_path = self.azure_output_dir + "config_inferer_used.json" + cini_path = self.azure_output_dir + "config_initializer_used.json" + if os.path.exists(cg_path): + os.remove(cg_path) + if os.path.exists(cinf_path): + os.remove(cinf_path) + if os.path.exists(cini_path): + os.remove(cini_path) + shutil.copy(GLOBAL_CONFIG_PATH, cg_path) + shutil.copy(INFERER_CONFIG_PATH, cinf_path) + shutil.copy(INITIALIZER_CONFIG_PATH, cini_path) + # sets up the initial conditions, initializer.get_initial_state() passed to runner + initializer = CovidSeroInitializer( + INITIALIZER_CONFIG_PATH, TEMP_GLOBAL_CONFIG_PATH + ) + runner = MechanisticRunner(seip_ode) + initial_state = initializer.get_initial_state() + initial_state = rework_initial_state(initial_state) + inferer = FallVirusInferer( + GLOBAL_CONFIG_PATH, INFERER_CONFIG_PATH, runner, initial_state + ) + ( + obs_hosps, + obs_hosps_days, + obs_sero_lmean, + obs_sero_lsd, + obs_sero_days, + obs_var_prop, + obs_var_days, + obs_var_sd, + ) = preprocess_observed_data(initializer, inferer, model_day) + + inferer.infer( + obs_hosps, + obs_hosps_days, + obs_sero_lmean, + obs_sero_lsd, + obs_sero_days, + obs_var_prop, + obs_var_days, + obs_var_sd, + ) + # saves all posterior samples including deterministic parameters + self.save_inference_posteriors(inferer) + self.save_inference_final_timesteps(inferer) + self.save_inference_timelines(inferer) + + +parser = argparse.ArgumentParser() +parser.add_argument( + "-s", + "--state", + type=str, + help="directory for the state to run, resembles USPS code of the state", +) + +parser.add_argument( + "-j", "--jobid", type=str, help="job-id of the state being run on Azure" +) +parser.add_argument( + "-l", "--local", action="store_false", help="scenario being run on Azure" +) + +if __name__ == "__main__": + args = parser.parse_args() + jobid = args.jobid + state = args.state + local = args.local + save_path = "/output/fifty_state_season2_5strain_2202_2404/%s/%s/" % ( + jobid, + state, + ) + runner = EpochOneRunner(save_path) + runner.process_state(state, jobid, jobid_in_path=local) diff --git a/exp/fifty_state_season2_5strain_2202_2404/template_configs/config_global.json b/exp/fifty_state_season2_5strain_2202_2404/template_configs/config_global.json new file mode 100644 index 00000000..68064c05 --- /dev/null +++ b/exp/fifty_state_season2_5strain_2202_2404/template_configs/config_global.json @@ -0,0 +1,55 @@ +{ + "DEMOGRAPHIC_DATA_PATH": "/input/data/demographic-newcensus-data/", + "REGIONS": [ + "Alabama" + ], + "INIT_DATE": "2022-02-11", + "VACCINATION_SEASON_CHANGE": "2024-07-01", + "AGE_LIMITS": [ + 0, + 18, + 50, + 65 + ], + "NUM_STRAINS": 5, + "MAX_VACCINATION_COUNT": 2, + "NUM_WANING_COMPARTMENTS": 4, + "WANING_TIMES": [ + 70, + 70, + 70, + 0 + ], + "COMPARTMENT_IDX": [ + "S", + "E", + "I", + "C" + ], + "S_AXIS_IDX": [ + "age", + "hist", + "vax", + "wane" + ], + "I_AXIS_IDX": [ + "age", + "hist", + "vax", + "strain" + ], + "C_AXIS_IDX": [ + "age", + "hist", + "vax", + "wane", + "strain" + ], + "STRAIN_IDX": [ + "omicron", + "BA2BA5", + "XBB1", + "XBB2", + "JN1" + ] +} diff --git a/exp/fifty_state_season2_5strain_2202_2404/template_configs/config_inferer.json b/exp/fifty_state_season2_5strain_2202_2404/template_configs/config_inferer.json new file mode 100644 index 00000000..6b9728c2 --- /dev/null +++ b/exp/fifty_state_season2_5strain_2202_2404/template_configs/config_inferer.json @@ -0,0 +1,477 @@ +{ + "SCENARIO_NAME": "test covid run for testing suite", + "CONTACT_MATRIX_PATH": "/input/data/demographic-data/contact_matrices", + "SAVE_PATH": "/output/", + "HOSP_PATH": "/input/data/hospitalization-data", + "SERO_PATH": "/input/data/serological-data/fitting-2022/", + "VAR_PATH": "/input/data/variant-data/fitting-2022/", + "VACCINATION_MODEL_DATA": "/input/data/vaccination-data/2022_02_11_to_2024_04_27", + "SEASONAL_VACCINATION": false, + "STRAIN_R0s": [ + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 50, + "concentration0": 50 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 1.0, + "scale": 2.0, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 35, + "concentration0": 65 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 2.0, + "scale": 2.0, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 30, + "concentration0": 70 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 2.0, + "scale": 3.0, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 30, + "concentration0": 70 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 2.0, + "scale": 3.0, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + 1 + ], + "INFECTIOUS_PERIOD": 7.0, + "EXPOSED_TO_INFECTIOUS": 3.6, + "INITIAL_INFECTIONS_SCALE": { + "distribution": "TruncatedNormal", + "params": { + "loc": 1.0, + "scale": 0.1, + "low": 0.5, + "high": 2.0 + } + }, + "INTRODUCTION_TIMES": [ + { + "distribution": "TruncatedNormal", + "params": { + "loc": 20, + "scale": 5, + "low": 10 + } + }, + { + "distribution": "TruncatedNormal", + "params": { + "loc": 230, + "scale": 5, + "low": 190 + } + }, + { + "distribution": "TruncatedNormal", + "params": { + "loc": 500, + "scale": 5, + "low": 450 + } + }, + { + "distribution": "TruncatedNormal", + "params": { + "loc": 640, + "scale": 5, + "low": 600 + } + } + ], + "CONSTANT_STEP_SIZE": 0, + "INTRODUCTION_PCTS": [ + 0.015, + 0.015, + 0.015, + 0.015 + ], + "INTRODUCTION_SCALES": [ + 18, + 18, + 18, + 18 + ], + "INTRODUCTION_AGE_MASK": [ + false, + true, + false, + false + ], + "MIN_HOMOLOGOUS_IMMUNITY": 0.40, + "WANING_PROTECTIONS": [ + 1.0, + 1.0, + 1.0, + 0.0 + ], + "STRAIN_INTERACTIONS": [ + [ + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 44.27, + "concentration0": 66.4 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + 1.0, + 1.0, + 1.0, + 1.0 + ], + [ + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 44.27, + "concentration0": 66.4 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + 1.0, + 1.0, + 1.0, + 1.0 + ], + [ + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 55.34, + "concentration0": 55.34 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 88.54, + "concentration0": 22.34 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + 1.0, + 1.0, + 1.0 + ], + [ + 0.25, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 55.34, + "concentration0": 55.34 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 88.54, + "concentration0": 22.34 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + 1.0, + 1.0 + ], + [ + 0.15, + 0.25, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 55.34, + "concentration0": 55.34 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 88.54, + "concentration0": 22.34 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.5, + "scale": 0.5, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + 1.0 + ] + ], + "VACCINE_EFF_MATRIX": [ + [ + 0, + 0.35, + 0.70 + ], + [ + 0, + 0.30, + 0.60 + ], + [ + 0, + 0.25, + 0.50 + ], + [ + 0, + 0.20, + 0.40 + ], + [ + 0, + 0.095, + 0.19 + ] + ], + "BETA_TIMES": [ + 0.0 + ], + "BETA_COEFICIENTS": [ + 1.0 + ], + "SEASONALITY_AMPLITUDE": { + "distribution": "TransformedDistribution", + "params": { + "base_distribution": { + "distribution": "Beta", + "params": { + "concentration1": 72, + "concentration0": 28 + } + }, + "transforms": { + "transform": "AffineTransform", + "params": { + "loc": 0.02, + "scale": 0.20, + "domain": { + "constraint": "unit_interval", + "params": {} + } + } + } + } + }, + "SEASONALITY_SECOND_WAVE": { + "distribution": "Beta", + "params": { + "concentration1": 50, + "concentration0": 50 + } + }, + "SEASONALITY_DOMINANT_WAVE_DAY": { + "distribution": "TruncatedNormal", + "params": { + "loc": 345, + "scale": 5, + "low": 325, + "high": 365 + } + }, + "SEASONALITY_SECOND_WAVE_DAY": { + "distribution": "TruncatedNormal", + "params": { + "loc": 230, + "scale": 5, + "low": 210, + "high": 270 + } + }, + "INFERENCE_PRNGKEY": 8675314, + "INFERENCE_NUM_WARMUP": 1000, + "INFERENCE_NUM_SAMPLES": 1000, + "INFERENCE_NUM_CHAINS": 4, + "INFERENCE_PROGRESS_BAR": true, + "MAX_TREE_DEPTH": 10, + "MODEL_RAND_SEED": 8675316 +} diff --git a/exp/fifty_state_season2_5strain_2202_2404/template_configs/config_initializer.json b/exp/fifty_state_season2_5strain_2202_2404/template_configs/config_initializer.json new file mode 100644 index 00000000..18486763 --- /dev/null +++ b/exp/fifty_state_season2_5strain_2202_2404/template_configs/config_initializer.json @@ -0,0 +1,50 @@ +{ + "INITALIZER_NAME": "base initializer covid", + "SEROLOGICAL_DATA_PATH": "/input/data/serological-data/sero-initializer-csvs", + "CONTACT_MATRIX_PATH": "/input/data/demographic-data/contact_matrices", + "POP_SIZE": 1000000, + "INITIAL_INFECTIONS": 40000, + "WANING_PROTECTIONS": [ + 1.0, + 1.0, + 1.0, + 0.0, + 0.0 + ], + "INFECTIOUS_PERIOD": 7.0, + "EXPOSED_TO_INFECTIOUS": 3.6, + "STRAIN_INTERACTIONS": [ + [ + 1.0, + 0.7, + 0.49 + ], + [ + 0.7, + 1.0, + 0.7 + ], + [ + 0.49, + 0.7, + 1.0 + ] + ], + "VACCINE_EFF_MATRIX": [ + [ + 0, + 0.40, + 0.80 + ], + [ + 0, + 0.35, + 0.70 + ], + [ + 0, + 0.30, + 0.60 + ] + ] +} diff --git a/exp/projections_ihr2_2404_2507/template_configs/scenario_template.json b/exp/projections_ihr2_2404_2507/template_configs/scenario_template.json index cb387327..bcf0e886 100644 --- a/exp/projections_ihr2_2404_2507/template_configs/scenario_template.json +++ b/exp/projections_ihr2_2404_2507/template_configs/scenario_template.json @@ -435,10 +435,12 @@ ] ], "BETA_TIMES": [ - 0.0, 70.0 + 0.0, + 70.0 ], "BETA_COEFICIENTS": [ - 1.0, 1.0 + 1.0, + 1.0 ], "SEASONALITY_AMPLITUDE": { "distribution": "TransformedDistribution", From 894a1126daeebb6de1ab685c0b2f6779aa6efc00 Mon Sep 17 00:00:00 2001 From: Ariel Shurygin Date: Tue, 1 Oct 2024 16:51:13 +0000 Subject: [PATCH 4/5] checkpoint, fitting now runs with the revamped experiment --- .../fall_virus_inferer.py | 87 ++++++++++++++++--- .../experiment_launcher_azure.py | 2 +- src/mechanistic_azure/experiment_setup.py | 8 +- 3 files changed, 79 insertions(+), 18 deletions(-) diff --git a/exp/fifty_state_season2_5strain_2202_2404/fall_virus_inferer.py b/exp/fifty_state_season2_5strain_2202_2404/fall_virus_inferer.py index 37fdfb8a..4c3335b9 100644 --- a/exp/fifty_state_season2_5strain_2202_2404/fall_virus_inferer.py +++ b/exp/fifty_state_season2_5strain_2202_2404/fall_virus_inferer.py @@ -1,5 +1,6 @@ import os +import jax import jax.numpy as jnp import numpyro import numpyro.distributions as Dist @@ -34,8 +35,8 @@ class FallVirusInferer(MechanisticInferer): "SEASONALITY_SECOND_WAVE", "SEASONALITY_DOMINANT_WAVE_DAY", "SEASONALITY_SECOND_WAVE_DAY", - "SEASONALITY_SHIFT", "MIN_HOMOLOGOUS_IMMUNITY", + "WANING_RATES", ] def load_vaccination_model(self): @@ -94,12 +95,40 @@ def infer( obs_var_prop=obs_var_prop, obs_var_days=obs_var_days, obs_var_sd=obs_var_sd, + tf=max(obs_hosps_days) + 1, ) self.inference_algo.print_summary() self.infer_complete = True self.inference_timesteps = max(obs_hosps_days) + 1 return self.inference_algo + def strain_interaction_to_cross_immunity( + self, num_strains: int, strain_interactions: jax.Array + ) -> jax.Array: + """Because we are overriding the definitions of immune history, the model must + contain its own method for generating the cross immunity matrix, not relying on + utils.strain_interaction_to_cross_immunity() which has complex logic + to convert the strain interactions matrix to a cross immunity matrix + + Parameters + ---------- + num_strains : int + number of strains in the model + strain_interactions : jax.Array + the strain interactions matrix + + Returns + ------- + jax.Array + the cross immunity matrix which is similar to the strain + matrix but with an added 0 layer for no previous exposure + """ + cim = jnp.hstack( + (jnp.array([[0.0]] * num_strains), strain_interactions), + ) + + return cim + def generate_downstream_parameters(self, parameters: dict) -> dict: """An Override for the generate downstream parameters in order to lock in transmisibility of certain strains to dependent on @@ -129,7 +158,42 @@ def generate_downstream_parameters(self, parameters: dict) -> dict: an appended onto version of `parameters` with additional downstream parameters added. """ - parameters = super().generate_downstream_parameters(parameters) + parameters[ + "CROSSIMMUNITY_MATRIX" + ] = self.strain_interaction_to_cross_immunity( + parameters["NUM_STRAINS"], + parameters["STRAIN_INTERACTIONS"], + ) + beta = parameters["STRAIN_R0s"] / parameters["INFECTIOUS_PERIOD"] + gamma = 1 / parameters["INFECTIOUS_PERIOD"] + sigma = 1 / parameters["EXPOSED_TO_INFECTIOUS"] + external_i_function_prefilled = jax.tree_util.Partial( + self.external_i, + introduction_times=parameters["INTRODUCTION_TIMES"], + introduction_scales=parameters["INTRODUCTION_SCALES"], + introduction_pcts=parameters["INTRODUCTION_PCTS"], + ) + # override the seasonality prefill since we are using alternative definition of seasonality + seasonality_function_prefilled = jax.tree_util.Partial( + self.seasonality, + seasonality_amplitude=parameters["SEASONALITY_AMPLITUDE"], + seasonality_second_wave=parameters["SEASONALITY_SECOND_WAVE"], + dominant_wave_day=parameters["SEASONALITY_DOMINANT_WAVE_DAY"], + second_wave_day=parameters["SEASONALITY_SECOND_WAVE_DAY"], + ) + parameters = dict( + parameters, + **{ + "BETA": beta, + "SIGMA": sigma, + "GAMMA": gamma, + "EXTERNAL_I": external_i_function_prefilled, + "VACCINATION_RATES": self.vaccination_rate, + "BETA_COEF": self.beta_coef, + "SEASONAL_VACCINATION_RESET": self.seasonal_vaccination_reset, + "SEASONALITY": seasonality_function_prefilled, + } + ) parameters["STRAIN_R0s"] = jnp.array( [ parameters["STRAIN_R0s"][0], @@ -389,16 +453,15 @@ def run_simulation(self, tf): def likelihood( self, - obs_hosps=None, - obs_hosps_days=None, - obs_sero_lmean=None, - obs_sero_lsd=None, - obs_sero_days=None, - obs_var_prop=None, - obs_var_days=None, - obs_var_sd=None, - tf=None, - infer_mode=True, + obs_hosps, + obs_hosps_days, + obs_sero_lmean, + obs_sero_lsd, + obs_sero_days, + obs_var_prop, + obs_var_days, + obs_var_sd, + tf, ): """ overridden likelihood that takes as input weekly hosp data starting from self.config.INIT_DATE diff --git a/src/mechanistic_azure/experiment_launcher_azure.py b/src/mechanistic_azure/experiment_launcher_azure.py index 76029c0c..11416041 100644 --- a/src/mechanistic_azure/experiment_launcher_azure.py +++ b/src/mechanistic_azure/experiment_launcher_azure.py @@ -5,7 +5,7 @@ import argparse -from .azure_utilities import AzureExperimentLauncher +from azure_utilities import AzureExperimentLauncher # specify job ID, cant already exist diff --git a/src/mechanistic_azure/experiment_setup.py b/src/mechanistic_azure/experiment_setup.py index d4375c25..1299aef8 100644 --- a/src/mechanistic_azure/experiment_setup.py +++ b/src/mechanistic_azure/experiment_setup.py @@ -15,11 +15,9 @@ # these are the configs that will be copied into each state-level directory # their REGIONS key will be modified to match the state they work with. CONFIG_MOLDS = [ - "exp/example_azure_experiment/example_template_configs/config_global.json", - "exp/example_azure_experiment/example_template_configs/config_inferer_covid.json", - "exp/example_azure_experiment/example_template_configs/config_initializer_covid.json", - "exp/example_azure_experiment/example_template_configs/config_runner_covid.json", - "exp/example_azure_experiment/example_template_configs/config_interpreter_covid.json", + "exp/fifty_state_season2_5strain_2202_2404/template_configs/config_global.json", + "exp/fifty_state_season2_5strain_2202_2404/template_configs/config_inferer.json", + "exp/fifty_state_season2_5strain_2202_2404/template_configs/config_initializer.json", ] EXPERIMENT_DIRECTORY = "exp" From 55ffa939b1840eb116826e958a2640da0eda667f Mon Sep 17 00:00:00 2001 From: Ariel Shurygin Date: Thu, 3 Oct 2024 23:43:08 +0000 Subject: [PATCH 5/5] a quick fix on the shinyapp to get it working again --- shiny_visualizers/azure_visualizer.py | 2 +- shiny_visualizers/shiny_utils.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/shiny_visualizers/azure_visualizer.py b/shiny_visualizers/azure_visualizer.py index a85959f9..6b3f0060 100644 --- a/shiny_visualizers/azure_visualizer.py +++ b/shiny_visualizers/azure_visualizer.py @@ -22,7 +22,7 @@ # this will reduce the time it takes to load the azure connection, but only shows # one experiment worth of data, which may be what you want... # leave empty ("") to explore all experiments -PRE_FILTER_EXPERIMENTS = "" +PRE_FILTER_EXPERIMENTS = "fifty_state_season2_5strain_2202_2404" # when loading the overview timelines csv for each run, columns # are expected to have names corresponding to the type of plot they create # vaccination_0_17 specifies the vaccination_ plot type, multiple columns may share diff --git a/shiny_visualizers/shiny_utils.py b/shiny_visualizers/shiny_utils.py index 14c60336..bb588500 100644 --- a/shiny_visualizers/shiny_utils.py +++ b/shiny_visualizers/shiny_utils.py @@ -12,7 +12,7 @@ from plotly.subplots import make_subplots from mechanistic_azure.azure_utilities import download_directory_from_azure -from resp_ode.utils import flatten_list_parameters +from resp_ode.utils import drop_keys_with_substring, flatten_list_parameters class Node: @@ -258,6 +258,8 @@ def load_checkpoint_inference_chains( # any sampled parameters created via numpyro.plate will mess up the data # flatten plated parameters into separate keys posteriors: dict[str, list] = flatten_list_parameters(posteriors) + # drop any final_timestep variables if they exist within the posteriors + posteriors = drop_keys_with_substring(posteriors, "final_timestep") num_sampled_parameters = len(posteriors.keys()) # we want a mostly square subplot, so lets sqrt and take floor/ceil to deal with odd numbers num_rows = math.isqrt(num_sampled_parameters) @@ -331,6 +333,8 @@ def load_checkpoint_inference_correlations( posteriors = { key: np.array(matrix).flatten() for key, matrix in posteriors.items() } + # drop any final_timestep parameters in case they snuck in + posteriors = drop_keys_with_substring(posteriors, "final_timestep") # Compute the correlation matrix, reverse it so diagonal starts @ top left correlation_matrix = pd.DataFrame(posteriors).corr()[::-1]