From 8eb45f3a1e4acc77258cbc7a909db6794b054c52 Mon Sep 17 00:00:00 2001 From: Ariel Shurygin <39861882+arik-shurygin@users.noreply.github.com> Date: Wed, 6 Nov 2024 09:09:19 -0800 Subject: [PATCH] Priors visualizer (#278) * checkpoint, adding overview plot done in plt instead of plotly * checkpoint fixing the overview plots and adding the pairwise code * updating comments to match line_length limits, adding mcmc chain plot * checkpoint, integrating vis_utils into default behavior of abstract_azure_runner * increasing fig size of overview * changing size of the correlation_pairs plot * tight bounding boxes to avoid text cutoff * adding plotly back since it is still used by the azure visualizer for now * checkpoint, first draft * adding a viz for prior distributions * lowering number of samples for prior distributions --- .../shiny_visualizers/azure_visualizer.py | 29 ++++- .../shiny_visualizers/shiny_utils.py | 34 ++++++ src/resp_ode/config.py | 3 + src/resp_ode/vis_utils.py | 109 +++++++++++++++++- 4 files changed, 171 insertions(+), 4 deletions(-) diff --git a/src/mechanistic_azure/shiny_visualizers/azure_visualizer.py b/src/mechanistic_azure/shiny_visualizers/azure_visualizer.py index c2ddca03..9324ec07 100644 --- a/src/mechanistic_azure/shiny_visualizers/azure_visualizer.py +++ b/src/mechanistic_azure/shiny_visualizers/azure_visualizer.py @@ -22,9 +22,7 @@ # this will reduce the time it takes to load the azure connection, but only shows # one experiment worth of data, which may be what you want... 
# leave empty ("") to explore all experiments -PRE_FILTER_EXPERIMENTS = ( - "example_azure_experiment" # fifty_state_season2_5strain_2202_2404 -) +PRE_FILTER_EXPERIMENTS = "" # when loading the overview timelines csv for each run, columns # are expected to have names corresponding to the type of plot they create # vaccination_0_17 specifies the vaccination_ plot type, multiple columns may share @@ -170,6 +168,12 @@ "Sample Violin Plots", output_widget("plot_sample_violins"), ), + ui.nav_panel( + "Config Visualizer", + ui.output_plot( + "plot_prior_distributions", width=1600, height=1600 + ), + ), ), ), ) @@ -369,6 +373,25 @@ def plot_sample_correlations(): print("displaying correlations plot") return fig + @output(id="plot_prior_distributions") + @render.plot + @reactive.event(input.action_button) + def plot_prior_distributions(): + exp = input.experiment() + job_id = input.job_id() + states = input.states() + scenario = input.scenario() + theme = input.dark_mode() + theme = sutils.shiny_to_matplotlib_theme(theme) + cache_paths = sutils.get_azure_files( + exp, job_id, states, scenario, azure_client, SHINY_CACHE_PATH + ) + # load the config this run used and plot its priors in the chosen theme + fig = sutils.load_prior_distributions_plot(cache_paths[0], theme) + # the figure is already themed, so it can be returned to shiny as is + print("displaying prior distributions") + return fig + @output(id="plot_sample_violins") @render_widget @reactive.event(input.action_button) diff --git a/src/mechanistic_azure/shiny_visualizers/shiny_utils.py b/src/mechanistic_azure/shiny_visualizers/shiny_utils.py index 872cba6a..717c2fc3 100644 --- a/src/mechanistic_azure/shiny_visualizers/shiny_utils.py +++ b/src/mechanistic_azure/shiny_visualizers/shiny_utils.py @@ -19,6 +19,7 @@ from tqdm import tqdm from mechanistic_azure.azure_utilities import download_directory_from_azure +from resp_ode import Config, vis_utils from resp_ode.utils import drop_keys_with_substring, 
flatten_list_parameters @@ -309,6 +310,22 @@ def load_checkpoint_inference_chains( return fig +def load_prior_distributions_plot(cache_path, matplotlib_theme): + path = os.path.join(cache_path, "config_inferer_used.json") + if os.path.exists(path): + config = Config(open(path).read()) + styles = ["seaborn-v0_8-colorblind", matplotlib_theme] + fig = vis_utils.plot_prior_distributions( + config.asdict(), matplotlib_style=styles + ) + else: + raise FileNotFoundError( + "%s does not exist, either the experiment did " + "not save a config used or loading files failed" % path + ) + return fig + + def load_checkpoint_inference_correlations( cache_path, overview_subplot_size: int, @@ -855,3 +872,20 @@ def shiny_to_plotly_theme(shiny_theme: str): plotly theme as str, used in `fig.update_layout(template=theme)` """ return "plotly_%s" % (shiny_theme if shiny_theme == "dark" else "white") + + +def shiny_to_matplotlib_theme(shiny_theme: str): + """shiny themes are "dark" and "light", matplotlib styles are + "dark_background" and "ggplot", this function converts from shiny to matplotlib style names + + Parameters + ---------- + shiny_theme : str + shiny theme as str + + Returns + ------- + str + matplotlib style as str, used in `plt.style.context(style)` + """ + return "dark_background" if shiny_theme == "dark" else "ggplot" diff --git a/src/resp_ode/config.py b/src/resp_ode/config.py index c52391ef..b6db54bc 100644 --- a/src/resp_ode/config.py +++ b/src/resp_ode/config.py @@ -56,6 +56,9 @@ def add_file(self, config_json_str): self.set_downstream_parameters() return self + def asdict(self): + return self.__dict__ + def convert_types(self, config): """ takes a dictionary of config parameters, consults the PARAMETERS global list and attempts to convert the type diff --git a/src/resp_ode/vis_utils.py b/src/resp_ode/vis_utils.py index ea6424db..6428986a 100644 --- a/src/resp_ode/vis_utils.py +++ b/src/resp_ode/vis_utils.py @@ -1,13 +1,23 @@ """A series of utility functions for 
generating visualizations for the model""" +import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns +from jax.random import PRNGKey from matplotlib.axes import Axes from matplotlib.colors import LinearSegmentedColormap -from .utils import drop_keys_with_substring, flatten_list_parameters +from .utils import ( + drop_keys_with_substring, + flatten_list_parameters, + identify_distribution_indexes, +) + + +class VisualizationError(Exception): + pass def _cleanup_and_normalize_timelines( @@ -460,3 +470,100 @@ def plot_mcmc_chains( handles, labels = ax.get_legend_handles_labels() fig.legend(handles, labels, loc="outside upper center") return fig + + +def plot_prior_distributions( + priors: dict[str], + matplotlib_style: list[str] + | str = [ + "seaborn-v0_8-colorblind", + ], + num_samples=5000, + hist_kwargs={"bins": 50, "density": True}, +) -> plt.Figure: + """Given a dictionary of parameter keys and possibly values of + numpyro.distribution objects, samples them a number of times + and returns a plot of those samples to help + visualize the range of values taken by that prior distribution. + + Parameters + ---------- + priors : dict[str: Any] + a dictionary with str keys possibly containing distribution + objects as values. Each distribution object found under a + key will be included in the plot + matplotlib_style : list[str] | str, optional + matplotlib style to plot in, by default ["seaborn-v0_8-colorblind"] + num_samples: int, optional + the number of times to sample each distribution, mild impact on + figure performance. By default 5000 + hist_kwargs: dict[str: Any] + additional kwargs passed to plt.hist(), by default {"bins": 50, "density": True} + + Returns + ------- + plt.Figure + matplotlib figure that is roughly square containing all distribution + keys found within priors. 
+ """ + dist_only = {} + d = identify_distribution_indexes(priors) + # filter down to just the distribution objects + for dist_name, locator_dct in d.items(): + parameter_name = locator_dct["sample_name"] + parameter_idx = locator_dct["sample_idx"] + # if the sample is on its own, not nested in a list, sample_idx is none + if parameter_idx is None: + dist_only[parameter_name] = priors[parameter_name] + # otherwise this sample is nested in a list and should be retrieved + else: + # go in index by index to access multi-dimensional lists + temp = priors[parameter_name] + for i in parameter_idx: + temp = temp[i] + dist_only[dist_name] = temp + param_names = list(dist_only.keys()) + num_params = len(param_names) + if num_params == 0: + raise VisualizationError( + "Attempted to visualize a config without any distributions" + ) + # Calculate the number of rows and columns for a square-ish layout + num_cols = int(np.ceil(np.sqrt(num_params))) + num_rows = int(np.ceil(num_params / num_cols)) + with plt.style.context(matplotlib_style): + fig, axs = plt.subplots( + num_rows, + num_cols, + figsize=(3 * num_cols, 3 * num_rows), + squeeze=False, + ) + # Flatten the axis array for easy indexing + axs_flat = axs.flatten() + # Loop over each parameter and sample + for i, param_name in enumerate(param_names): + ax: Axes = axs_flat[i] + ax.set_title(param_name) + dist = dist_only[param_name] + samples = dist.sample(PRNGKey(0), sample_shape=(num_samples,)) + ax.hist(samples, **hist_kwargs) + ax.axvline( + samples.mean(), + linestyle="dashed", + linewidth=1, + label="mean", + ) + ax.axvline( + jnp.median(samples), + linestyle="dotted", + linewidth=3, + label="median", + ) + # Turn off any unused subplots + for j in range(i + 1, len(axs_flat)): + axs_flat[j].axis("off") + handles, labels = ax.get_legend_handles_labels() + fig.legend(handles, labels, loc="outside upper right") + fig.suptitle("Prior Distributions Visualized, n=%s" % num_samples) + plt.tight_layout() + return fig