Skip to content

Commit 92dae99

Browse files
authored
Merge pull request #1057 from dstl/akkf
Fix failing test for AKKF proposal multivariate sampling to compare mean of particle state
2 parents c9bc0b8 + c5a84be commit 92dae99

File tree

4 files changed

+7
-7
lines changed

4 files changed

+7
-7
lines changed

stonesoup/predictor/kernel.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212

1313

1414
class AdaptiveKernelKalmanPredictor(KalmanPredictor):
15-
r"""An implementation of the adaptive kernel Kalman filter (AKKF) predictor. Here, the AKKF draws
16-
inspiration from the concepts of kernel mean embeddings (KME) and Kalman Filter to address
17-
tracking problems in nonlinear systems.
15+
r"""An implementation of the adaptive kernel Kalman filter (AKKF) predictor. Here, the AKKF
16+
draws inspiration from the concepts of kernel mean embeddings (KME) and Kalman Filter to
17+
address tracking problems in nonlinear systems.
1818
1919
In the state space, at time :math:`k`, the prior state
2020
particles are generated by passing the proposal particles at time :math:`k-1`, i.e.,

stonesoup/predictor/tests/test_kernel.py

-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from ...types.update import KernelParticleStateUpdate
1414

1515
number_particles = 4
16-
np.random.seed(50)
1716
timestamp = datetime.datetime.now()
1817
samples = multivariate_normal.rvs([0, 0, 0, 0],
1918
np.diag([0.01, 0.005, 0.1, 0.5])**2,

stonesoup/types/tests/test_state.py

-1
Original file line numberDiff line numberDiff line change
@@ -863,7 +863,6 @@ def test_asd_weighted_gaussian_state():
863863
def test_kernel_particle_state():
864864
number_particles = 5
865865
weights = np.array([1 / number_particles] * number_particles)
866-
np.random.seed(50)
867866

868867
samples = multivariate_normal.rvs([0, 0, 0, 0],
869868
np.diag([0.01, 0.005, 0.1, 0.5]) ** 2,

stonesoup/updater/tests/test_kernel.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
new_timestamp = timestamp + time_diff
2020

2121
number_particles = 5
22-
np.random.seed(50)
2322
samples = multivariate_normal.rvs([0, 0, 0, 0],
2423
np.diag([0.01, 0.005, 0.1, 0.5])**2,
2524
size=number_particles)
@@ -107,7 +106,10 @@ def test_kernel_updater(kernel, measurement_model, c, ialpha):
107106
assert measurement.timestamp == gt_state.timestamp
108107
assert update.hypothesis.measurement.timestamp == gt_state.timestamp
109108
assert np.allclose(update.state_vector, prediction.state_vector)
110-
assert np.allclose(update.proposal, StateVectors(new_state_vector.T), atol=1e0)
109+
assert np.allclose(
110+
np.mean(update.proposal, axis=1),
111+
np.mean(StateVectors(new_state_vector.T), axis=1),
112+
atol=2)
111113
assert np.allclose(update.weight, updated_weights.ravel())
112114
assert np.allclose(update.kernel_covar, updated_covariance)
113115

0 commit comments

Comments
 (0)