Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improving violin / prior dist visualizations #299

Merged
merged 5 commits into from
Jan 15, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 152 additions & 34 deletions src/dynode/vis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import pandas as pd
import seaborn as sns
from jax import Array
from jax.random import PRNGKey
from matplotlib.axes import Axes
from matplotlib.colors import LinearSegmentedColormap
Expand Down Expand Up @@ -476,6 +477,55 @@ def plot_mcmc_chains(
return fig


def _sample_prior_distributions(priors, num_samples) -> dict[str, Array]:
    """Draw `num_samples` samples from each distribution held in `priors`.

    Parameters
    ----------
    priors : dict[str, Any]
        Mapping of parameter names to values of any type. Entries may be
        distribution objects themselves or (possibly nested) indexable
        containers holding distribution objects.
    num_samples : int
        Number of draws requested from every distribution; passed as
        ``sample_shape=(num_samples,)``.

    Returns
    -------
    dict[str, jax.Array]
        One entry per distribution located by
        `identify_distribution_indexes`, holding the result of
        ``dist.sample(...)`` whose leading axis has length `num_samples`.

    Notes
    -----
    A distribution stored directly under a key is returned under that key
    (the locator's ``"sample_name"``). A distribution nested inside a
    container is returned under the name `identify_distribution_indexes`
    assigned to it -- presumably the parameter name followed by one ``_i``
    suffix per index; confirm against that helper.

    Every distribution is sampled with the same fixed ``PRNGKey(0)``, so
    repeated calls give identical draws.
    """
    # Phase 1: resolve every locator to the distribution object it points at.
    located = {}
    for site_name, locator in identify_distribution_indexes(priors).items():
        base_name = locator["sample_name"]
        assert isinstance(base_name, str)
        index_path = locator["sample_idx"]
        assert index_path is None or isinstance(index_path, tuple)
        if index_path is not None:
            # nested entry: descend one index at a time through the
            # (possibly multi-dimensional) container
            node = priors[base_name]
            for idx in index_path:
                node = node[idx]
            located[site_name] = node
        else:
            # top-level entry: the value under the key is the distribution
            located[base_name] = priors[base_name]
    # Phase 2: sample each located distribution.
    return {
        name: distribution.sample(PRNGKey(0), sample_shape=(num_samples,))
        for name, distribution in located.items()
    }


def plot_prior_distributions(
priors: dict[str, Any],
matplotlib_style: list[str]
Expand All @@ -484,6 +534,11 @@ def plot_prior_distributions(
],
num_samples=5000,
hist_kwargs={"bins": 50, "density": True},
median_line_kwargs={
"linestyle": "dotted",
"linewidth": 3,
"label": "prior median",
},
) -> plt.Figure:
"""Given a dictionary of parameter keys and possibly values of
numpyro.distribution objects, samples them a number of times
Expand All @@ -510,26 +565,8 @@ def plot_prior_distributions(
matplotlib figure that is roughly square containing all distribution
keys found within priors.
"""
dist_only = {}
d = identify_distribution_indexes(priors)
# filter down to just the distribution objects
for dist_name, locator_dct in d.items():
parameter_name = locator_dct["sample_name"]
assert isinstance(parameter_name, str)
parameter_idx = locator_dct["sample_idx"]
assert isinstance(parameter_idx, tuple) or parameter_idx is None

# if the sample is on its own, not nested in a list, sample_idx is none
if parameter_idx is None:
dist_only[parameter_name] = priors[parameter_name]
# otherwise this sample is nested in a list and should be retrieved
else:
# go in index by index to access multi-dimensional lists
temp = priors[parameter_name]
for i in parameter_idx:
temp = temp[i]
dist_only[dist_name] = temp
param_names = list(dist_only.keys())
sampled_priors = _sample_prior_distributions(priors, num_samples)
param_names = list(sampled_priors.keys())
num_params = len(param_names)
if num_params == 0:
raise VisualizationError(
Expand All @@ -551,21 +588,10 @@ def plot_prior_distributions(
for i, param_name in enumerate(param_names):
ax: Axes = axs_flat[i]
ax.set_title(param_name)
dist = dist_only[param_name]
samples = dist.sample(PRNGKey(0), sample_shape=(num_samples,))
samples = sampled_priors[param_name]
ax.hist(samples, **hist_kwargs)
ax.axvline(
samples.mean(),
linestyle="dashed",
linewidth=1,
label="mean",
)
ax.axvline(
jnp.median(samples).item(0), # This was an Array
linestyle="dotted",
linewidth=3,
label="median",
)
ax.axvline(float(jnp.median(samples)), **median_line_kwargs)
# testing
# Turn off any unused subplots
for j in range(i + 1, len(axs_flat)):
axs_flat[j].axis("off")
Expand All @@ -574,3 +600,95 @@ def plot_prior_distributions(
fig.suptitle("Prior Distributions Visualized, n=%s" % num_samples)
plt.tight_layout()
return fig


def _violin_long_frame(
    priors: dict[str, list] | None,
    posteriors: dict[str, list] | None,
) -> pd.DataFrame:
    """Stack prior/posterior samples into one long frame (values, type, param)."""
    frames = []
    for label, samples_by_param in (
        ("prior", priors),
        ("posterior", posteriors),
    ):
        if samples_by_param is None:
            continue
        for param, values in samples_by_param.items():
            # flatten any chains if they leaked in
            flat = np.array(values).flatten()
            if flat.size == 0:
                # nothing to draw for this parameter; skipping it keeps the
                # subplot grid in sync with what is actually plotted
                continue
            frames.append(
                pd.DataFrame({"values": flat, "type": label, "param": param})
            )
    if not frames:
        return pd.DataFrame(columns=["values", "type", "param"])
    # a single concat avoids re-copying the accumulated frame on every
    # iteration and never concatenates an empty frame
    return pd.concat(frames, ignore_index=True, axis=0)


def plot_violin_plots(
    priors: dict[str, list] | None = None,
    posteriors: dict[str, list] | None = None,
    matplotlib_style: list[str]
    | str = [
        "seaborn-v0_8-colorblind",
    ],
) -> plt.Figure:
    """Draw one violin subplot per parameter for priors and/or posteriors.

    Parameters
    ----------
    priors : dict[str, list] | None, optional
        Parameter name -> prior samples. Values are converted with
        ``np.array`` and flattened, so any array-like shape is accepted.
    posteriors : dict[str, list] | None, optional
        Parameter name -> posterior samples, flattened the same way.
    matplotlib_style : list[str] | str, optional
        Style(s) handed to ``plt.style.context`` while the figure is built.

    Returns
    -------
    plt.Figure
        Roughly square grid with one subplot per parameter found in either
        dictionary. When both dictionaries are given, every subplot reserves
        a "prior" and a "posterior" column, even if one side has no samples.

    Raises
    ------
    VisualizationError
        If both `priors` and `posteriors` are None, or if together they
        contain no samples to plot.
    """
    if priors is None and posteriors is None:
        raise VisualizationError(
            "must provide either a dictionary of priors or posteriors"
        )
    df = _violin_long_frame(priors, posteriors)
    # Size the grid from the parameters that will actually be drawn (the
    # union of both dicts), so a parameter present in only one of them can
    # never index past the end of the axes array.
    param_names = list(df["param"].unique())
    num_params = len(param_names)
    if num_params == 0:
        raise VisualizationError(
            "no parameters with samples were found to plot"
        )
    # Calculate the number of rows and columns for a square-ish layout
    num_cols = int(np.ceil(np.sqrt(num_params)))
    num_rows = int(np.ceil(num_params / num_cols))

    # Fixing the category order keeps two violin columns per subplot whenever
    # both dicts are supplied, whichever side a parameter is missing from.
    type_order = (
        ["prior", "posterior"]
        if priors is not None and posteriors is not None
        else None
    )

    # parameters that share a first word will be colored the same for
    # interpretability; dict.fromkeys de-duplicates while preserving order so
    # the color assignment is identical from run to run
    first_words = list(
        dict.fromkeys(param.split("_")[0] for param in param_names)
    )
    color_palette = sns.color_palette("Set2", n_colors=len(first_words))
    color_dict = dict(zip(first_words, color_palette))

    with plt.style.context(matplotlib_style):
        fig, axs = plt.subplots(
            num_rows,
            num_cols,
            figsize=(3 * num_cols, 3 * num_rows),
            squeeze=False,
        )
        axs = axs.flatten()

        for param, ax in zip(param_names, axs):
            sns.violinplot(
                data=df.loc[df["param"] == param],
                x="type",
                y="values",
                order=type_order,
                ax=ax,
                color=color_dict[param.split("_")[0]],
            )
            ax.set_title(param)
            ax.set_ylabel("")
            ax.set_xlabel("")

        # Remove the unused trailing subplots of the grid
        for unused_ax in axs[num_params:]:
            fig.delaxes(unused_ax)

        handles, labels = axs[num_params - 1].get_legend_handles_labels()
        if handles:
            # only add a legend when there is something to show; an empty
            # legend would otherwise be added to the figure
            fig.legend(handles, labels, loc="outside upper right")
        fig.suptitle("Violin Plot of Parameters")
        fig.tight_layout()
    return fig
Loading