Skip to content

Commit

Permalink
improving violin / prior dist visualizations (#299)
Browse files Browse the repository at this point in the history
* checkpoint, upgrading violin and priors visualizers on vis_utils for use in shiny apps

* bugfixing 2+ dimensional inputs to plot_violint_plots() and dict.keys() bugfix

* plot_violin_plots() now works with either priors OR posteriors specified

* fixing mypy
  • Loading branch information
arik-shurygin authored Jan 15, 2025
1 parent 9cc21ae commit 3525a4a
Showing 1 changed file with 152 additions and 34 deletions.
186 changes: 152 additions & 34 deletions src/dynode/vis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import pandas as pd
import seaborn as sns
from jax import Array
from jax.random import PRNGKey
from matplotlib.axes import Axes
from matplotlib.colors import LinearSegmentedColormap
Expand Down Expand Up @@ -476,6 +477,55 @@ def plot_mcmc_chains(
return fig


def _sample_prior_distributions(priors, num_samples) -> dict[str, Array]:
"""Sample numpyro.distributions `num_samples` times.
Parameters
----------
priors : dict[str, Any]
A dictionary containing keys of different parameter
names and values of any type.
num_samples : int
number of times to sample numpyro.distribution objects.
Returns
-------
dict[str, jax.Array]
Numpyro sample site name with jax.Array(shape=(num_samples,)) for each
numpyro.distribution found within `priors`.
Notes
--------
Return dict key names follow the same naming convention as when sampling.
Meaning that distributions within lists or matricies have their
index stored as a list of _i suffix at the end of their name.
"""
dist_only = {}
d = identify_distribution_indexes(priors)
# filter down to just the distribution objects
for dist_name, locator_dct in d.items():
parameter_name = locator_dct["sample_name"]
assert isinstance(parameter_name, str)
parameter_idx = locator_dct["sample_idx"]
assert isinstance(parameter_idx, tuple) or parameter_idx is None
# if the sample is on its own, not nested in a list, sample_idx is none
if parameter_idx is None:
dist_only[parameter_name] = priors[parameter_name]
# otherwise this sample is nested in a list and should be retrieved
else:
temp = priors[parameter_name]
# go into multi-dimensional matricies one index at a time
for i in parameter_idx:
temp = temp[i]
dist_only[dist_name] = temp
sampled_priors = {}
for param, dist in dist_only.items():
sampled_priors[param] = dist.sample(
PRNGKey(0), sample_shape=(num_samples,)
)
return sampled_priors


def plot_prior_distributions(
priors: dict[str, Any],
matplotlib_style: list[str]
Expand All @@ -484,6 +534,11 @@ def plot_prior_distributions(
],
num_samples=5000,
hist_kwargs={"bins": 50, "density": True},
median_line_kwargs={
"linestyle": "dotted",
"linewidth": 3,
"label": "prior median",
},
) -> plt.Figure:
"""Given a dictionary of parameter keys and possibly values of
numpyro.distribution objects, samples them a number of times
Expand All @@ -510,26 +565,8 @@ def plot_prior_distributions(
matplotlib figure that is roughly square containing all distribution
keys found within priors.
"""
dist_only = {}
d = identify_distribution_indexes(priors)
# filter down to just the distribution objects
for dist_name, locator_dct in d.items():
parameter_name = locator_dct["sample_name"]
assert isinstance(parameter_name, str)
parameter_idx = locator_dct["sample_idx"]
assert isinstance(parameter_idx, tuple) or parameter_idx is None

# if the sample is on its own, not nested in a list, sample_idx is none
if parameter_idx is None:
dist_only[parameter_name] = priors[parameter_name]
# otherwise this sample is nested in a list and should be retrieved
else:
# go in index by index to access multi-dimensional lists
temp = priors[parameter_name]
for i in parameter_idx:
temp = temp[i]
dist_only[dist_name] = temp
param_names = list(dist_only.keys())
sampled_priors = _sample_prior_distributions(priors, num_samples)
param_names = list(sampled_priors.keys())
num_params = len(param_names)
if num_params == 0:
raise VisualizationError(
Expand All @@ -551,21 +588,10 @@ def plot_prior_distributions(
for i, param_name in enumerate(param_names):
ax: Axes = axs_flat[i]
ax.set_title(param_name)
dist = dist_only[param_name]
samples = dist.sample(PRNGKey(0), sample_shape=(num_samples,))
samples = sampled_priors[param_name]
ax.hist(samples, **hist_kwargs)
ax.axvline(
samples.mean(),
linestyle="dashed",
linewidth=1,
label="mean",
)
ax.axvline(
jnp.median(samples).item(0), # This was an Array
linestyle="dotted",
linewidth=3,
label="median",
)
ax.axvline(float(jnp.median(samples)), **median_line_kwargs)
# testing
# Turn off any unused subplots
for j in range(i + 1, len(axs_flat)):
axs_flat[j].axis("off")
Expand All @@ -574,3 +600,95 @@ def plot_prior_distributions(
fig.suptitle("Prior Distributions Visualized, n=%s" % num_samples)
plt.tight_layout()
return fig


def plot_violin_plots(
priors: dict[str, list] | None = None,
posteriors: dict[str, list] | None = None,
matplotlib_style: list[str]
| str = [
"seaborn-v0_8-colorblind",
],
):
if priors is None and posteriors is None:
raise VisualizationError(
"must provide either a dictionary of priors or posteriors"
)
# we are given that both are not none, so get num_params from one of them
if posteriors is not None:
num_params = len(posteriors.keys())
elif priors is not None:
num_params = len(priors.keys())
# Calculate the number of rows and columns for a square-ish layout
num_cols = int(np.ceil(np.sqrt(num_params)))
num_rows = int(np.ceil(num_params / num_cols))

with plt.style.context(matplotlib_style):
fig, axs = plt.subplots(
num_rows,
num_cols,
figsize=(3 * num_cols, 3 * num_rows),
squeeze=False,
)
axs = axs.flatten()

df = pd.DataFrame()
if priors is not None:
for param, values in priors.items():
df_param = pd.DataFrame()
# flatten any chains if they leaked in
df_param["values"] = np.array(values).flatten()
df_param["type"] = "prior"
df_param["param"] = param
df = pd.concat([df, df_param], ignore_index=True, axis=0)
if posteriors is not None:
for param, values in posteriors.items():
df_param = pd.DataFrame()
# flatten any chains if they leaked in
df_param["values"] = np.array(values).flatten()
df_param["type"] = "posterior"
df_param["param"] = param
# this is necessary to make sure there are always two columns
# of violin plots, including when a posterior does not have an
# associated prior
if priors is not None and param not in priors.keys():
filler_row = pd.DataFrame(
{"values": [np.nan], "type": "prior", "param": param}
)
df_param = pd.concat(
[filler_row, df_param], ignore_index=True, axis=0
)
df = pd.concat([df, df_param], ignore_index=True, axis=0)

# parameters that share a first word will be colored the same for interpretability
unique_first_words = set(
[param.split("_")[0] for param in df["param"].unique()]
)
color_palette = sns.color_palette("Set2", n_colors=len(unique_first_words))
color_dict = dict(zip(unique_first_words, color_palette))

# Iterate over the parameters and create violin plots
for i, param in enumerate(df["param"].unique()):
ax: Axes = axs[i]
# if priors is not None and param in priors.keys():
# sns.violinplot(y=priors[param], ax=ax, alpha=0.5, label="prior")
sns.violinplot(
data=df.loc[df["param"] == param],
x="type",
y="values",
ax=ax,
color=color_dict[param.split("_")[0]],
)
ax.set_title(param)
ax.set_ylabel("")
ax.set_xlabel("")

# Remove empty subplots if necessary
if num_params < num_rows * num_cols:
for i in range(num_params, num_rows * num_cols):
fig.delaxes(axs[i])
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc="outside upper right")
fig.suptitle("Violin Plot of Parameters")
fig.tight_layout()
return fig

0 comments on commit 3525a4a

Please sign in to comment.