Skip to content

Commit

Permalink
fixing mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
arik-shurygin committed Jan 17, 2025
1 parent e18dcba commit 21bc746
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/dynode/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,7 +1122,7 @@ def drop_sample_chains(samples: dict, dropped_chain_vals: list):

def flatten_list_parameters(
samples: dict[str, np.ndarray | Array],
) -> dict[str, np.ndarray]:
) -> dict[str, np.ndarray | Array]:
"""
Flatten plated parameters into separate keys in the samples dictionary.
Expand Down
4 changes: 2 additions & 2 deletions src/dynode/vis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def plot_checkpoint_inference_correlation_pairs(
Figure with n rows and n columns where n is the number of sampled parameters.
"""
# convert lists to np.arrays
posteriors: dict[str, np.ndarray] = flatten_list_parameters(
posteriors: dict[str, np.ndarray | Array] = flatten_list_parameters(
{
key: np.array(val) if isinstance(val, list) else val
for key, val in posteriors_in.items()
Expand Down Expand Up @@ -408,7 +408,7 @@ def plot_mcmc_chains(
Matplotlib figure containing the plots.
"""
# Determine the number of parameters and chains
samples: dict[str, np.ndarray] = flatten_list_parameters(
samples: dict[str, np.ndarray | Array] = flatten_list_parameters(
{
key: np.array(val) if isinstance(val, list) else val
for key, val in samples_in.items()
Expand Down

0 comments on commit 21bc746

Please sign in to comment.