From 232eb70465160a4666195f7bf0077247342ae5cf Mon Sep 17 00:00:00 2001 From: James Wright Date: Wed, 20 Nov 2024 14:54:12 +0000 Subject: [PATCH] Add kernel covar shape tests --- stonesoup/tests/test_kernel.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/stonesoup/tests/test_kernel.py b/stonesoup/tests/test_kernel.py index 02c07ff37..5fc7223c9 100644 --- a/stonesoup/tests/test_kernel.py +++ b/stonesoup/tests/test_kernel.py @@ -161,8 +161,15 @@ "quartic_prior_as_state_vector", "quartic_proposal_as_state_vector"] ) def test_kernel(kernel_class, output, state1, state2): - kernel = kernel_class(state1, state2) - assert np.allclose(kernel, output) + kernel_covar = kernel_class(state1, state2) + sv1 = state1.state_vector if isinstance(state1, State) else state1 + print(state2, state2 is None) + if state2 is not None: + sv2 = state2.state_vector if isinstance(state2, State) else state2 + else: + sv2 = sv1 + assert kernel_covar.shape == (sv1.shape[1], sv2.shape[1]) + assert np.allclose(kernel_covar, output) def test_not_implemented(): @@ -186,6 +193,7 @@ def test_multiplicative_kernel(power): multiplicative_covar = multiplicative_kernel(state1, state2) polynomial_covar = polynomial_kernel(state1, state2) + assert multiplicative_covar.shape == (state1.shape[1], state2.shape[1]) assert np.allclose(linear_covar, multiplicative_covar) assert np.allclose(multiplicative_covar, polynomial_covar) assert np.allclose(linear_covar, polynomial_covar) @@ -220,8 +228,9 @@ def test_additive_kernel(kernel_class): kernel_list = [kernel] * 2 additive_kernel = AdditiveKernel(kernel_list=kernel_list) state1 = StateVectors([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]) - state2 = StateVectors([[2, 2, 2], [3, 3, 3], [4, 4, 4], [5, 5, 5]]) + state2 = StateVectors([[2, 2], [3, 3], [4, 4], [5, 5]]) linear_covar = kernel(state1, state2) + assert linear_covar.shape == (state1.shape[1], state2.shape[1]) assert np.allclose(linear_covar + linear_covar, additive_kernel(state1, state2))