From e18dcba14ec7eaac7927db9e6c6d93e696f5f481 Mon Sep 17 00:00:00 2001 From: Ariel Shurygin Date: Fri, 17 Jan 2025 20:45:00 +0000 Subject: [PATCH] hotfix to flatten_list_parameters not working with jax array, adding tests --- src/dynode/utils.py | 6 ++--- tests/test_utils.py | 55 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 3 deletions(-) diff --git a/src/dynode/utils.py b/src/dynode/utils.py index 5bb1d5dd..4bdd55ff 100644 --- a/src/dynode/utils.py +++ b/src/dynode/utils.py @@ -1121,14 +1121,14 @@ def drop_sample_chains(samples: dict, dropped_chain_vals: list): def flatten_list_parameters( - samples: dict[str, np.ndarray], + samples: dict[str, np.ndarray | Array], ) -> dict[str, np.ndarray]: """ Flatten plated parameters into separate keys in the samples dictionary. Parameters ---------- - samples : dict[str, np.ndarray] + samples : dict[str, np.ndarray | Array] Dictionary with parameter names as keys and sample arrays as values. Arrays may have shape MxNxP for P independent draws. @@ -1144,7 +1144,7 @@ def flatten_list_parameters( """ return_dict = {} for key, value in samples.items(): - if isinstance(value, np.ndarray) and value.ndim > 2: + if isinstance(value, (np.ndarray, Array)) and value.ndim > 2: num_dims = value.ndim - 2 indices = ( np.indices(value.shape[-num_dims:]).reshape(num_dims, -1).T diff --git a/tests/test_utils.py b/tests/test_utils.py index 32a548a4..24a45a74 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,6 +3,7 @@ from enum import IntEnum import jax.numpy as jnp +import numpy as np import numpyro.distributions as dist from dynode import utils @@ -426,3 +427,57 @@ def test_get_timeline_from_solution_with_command_compartment_slice(): assert jnp.all( timeline == 16 ) # Each element in sol is 1, summed over 4*1*1*4 = 16 + + +def test_flatten_list_params_numpy(): + # simulate 4 chains and 20 samples each with 4 plated parameters + testing = {"test": np.ones((4, 20, 5))} + flattened = utils.flatten_list_parameters(testing) + assert "test" not in flattened.keys() + for suffix in range(5): + key = "test_%s" % str(suffix) + assert ( + key in flattened.keys() + ), "flatten_list_parameters not naming split params correctly." + assert flattened[key].shape == ( + 4, + 20, + ), "flatten_list_parameters breaking up wrong axis" + + +def test_flatten_list_params_jax_numpy(): + # simulate 4 chains and 20 samples each with 4 plated parameters + # this time with jax numpy instead of numpy + testing = {"test": jnp.ones((4, 20, 5))} + flattened = utils.flatten_list_parameters(testing) + assert "test" not in flattened.keys() + for suffix in range(5): + key = "test_%s" % str(suffix) + assert ( + key in flattened.keys() + ), "flatten_list_parameters not naming split params correctly." + assert flattened[key].shape == ( + 4, + 20, + ), "flatten_list_parameters breaking up wrong axis" + + +def test_flatten_list_params_multi_dim(): + # simulate 4 chains and 20 samples each with 10 plated parameters + # this time with jax numpy instead of numpy + testing = {"test": jnp.ones((4, 20, 5, 2))} + flattened = utils.flatten_list_parameters(testing) + assert "test" not in flattened.keys() + for suffix_first_dim in range(5): + for suffix_second_dim in range(2): + key = "test_%s_%s" % ( + str(suffix_first_dim), + str(suffix_second_dim), + ) + assert ( + key in flattened.keys() + ), "flatten_list_parameters not naming split params correctly." + assert flattened[key].shape == ( + 4, + 20, + ), "flatten_list_parameters breaking up wrong axis when passed >3"