From 96d62f6d093d9ae652a0e7f5d83c812d010eb73d Mon Sep 17 00:00:00 2001 From: Ariel Shurygin <39861882+arik-shurygin@users.noreply.github.com> Date: Tue, 8 Oct 2024 11:41:23 -0700 Subject: [PATCH] a quick fix on the azure shiny app to get it working again (#262) * a quick fix on the shinyapp to get it working again * forgot to turn off pre-filtering while testing * helper function was not uploaded --- shiny_visualizers/shiny_utils.py | 6 +++++- src/resp_ode/utils.py | 21 +++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/shiny_visualizers/shiny_utils.py b/shiny_visualizers/shiny_utils.py index 14c60336..bb588500 100644 --- a/shiny_visualizers/shiny_utils.py +++ b/shiny_visualizers/shiny_utils.py @@ -12,7 +12,7 @@ from plotly.subplots import make_subplots from mechanistic_azure.azure_utilities import download_directory_from_azure -from resp_ode.utils import flatten_list_parameters +from resp_ode.utils import drop_keys_with_substring, flatten_list_parameters class Node: @@ -258,6 +258,8 @@ def load_checkpoint_inference_chains( # any sampled parameters created via numpyro.plate will mess up the data # flatten plated parameters into separate keys posteriors: dict[str, list] = flatten_list_parameters(posteriors) + # drop any final_timestep variables if they exist within the posteriors + posteriors = drop_keys_with_substring(posteriors, "final_timestep") num_sampled_parameters = len(posteriors.keys()) # we want a mostly square subplot, so lets sqrt and take floor/ceil to deal with odd numbers num_rows = math.isqrt(num_sampled_parameters) @@ -331,6 +333,8 @@ def load_checkpoint_inference_correlations( posteriors = { key: np.array(matrix).flatten() for key, matrix in posteriors.items() } + # drop any final_timestep parameters in case they snuck in + posteriors = drop_keys_with_substring(posteriors, "final_timestep") # Compute the correlation matrix, reverse it so diagonal starts @ top left correlation_matrix = pd.DataFrame(posteriors).corr()[::-1] diff --git a/src/resp_ode/utils.py b/src/resp_ode/utils.py index 4929a3f8..4597d80c 100644 --- a/src/resp_ode/utils.py +++ b/src/resp_ode/utils.py @@ -1025,6 +1025,27 @@ def flatten_list_parameters( return return_dict +def drop_keys_with_substring(dct: dict[str], drop_s: str): + """A simple helper function designed to drop keys from a dictionary if they contain some substring + + Parameters + ---------- + dct : dict[str, Any] + a dictionary with string keys + drop_s : str + keys containing `drop_s` as a substring will be dropped + + Returns + ------- + dict[str, any] + dct with keys containing drop_s removed, otherwise untouched. + """ + keys_to_drop = [key for key in dct.keys() if drop_s in key] + for key in keys_to_drop: + del dct[key] + return dct + + # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ # DEATH CALCULATION CODE # @@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@