diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f7907b311..b3f172cbb5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Upgrade gradle to 8.4 [1289](https://github.com/opensearch-project/k-NN/pull/1289) * Refactor security testing to install from individual components [#1307](https://github.com/opensearch-project/k-NN/pull/1307) * Refactor integ tests that access model index [#1423](https://github.com/opensearch-project/k-NN/pull/1423) +* Fix flaky model tests [#1429](https://github.com/opensearch-project/k-NN/pull/1429) ### Documentation ### Maintenance * Update developer guide to include M1 Setup [#1222](https://github.com/opensearch-project/k-NN/pull/1222) diff --git a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java index 1297dc1845..2af8df953e 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java @@ -14,9 +14,11 @@ import org.junit.AfterClass; import org.junit.Assert; import org.junit.BeforeClass; +import org.mockito.MockedStatic; import org.opensearch.ExceptionsHelper; import org.opensearch.ResourceAlreadyExistsException; import org.opensearch.ResourceNotFoundException; +import org.opensearch.cluster.ClusterChangedEvent; import org.opensearch.core.action.ActionListener; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.StepListener; @@ -46,6 +48,7 @@ import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction; import org.opensearch.knn.plugin.transport.UpdateModelGraveyardRequest; import org.opensearch.core.rest.RestStatus; +import org.opensearch.knn.training.TrainingJobClusterStateListener; import java.io.IOException; import java.time.ZoneOffset; @@ -57,6 +60,10 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; import static org.opensearch.cluster.metadata.Metadata.builder; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; @@ -73,15 +80,22 @@ public class ModelDaoTests extends KNNSingleNodeTestCase { private static ExecutorService modelGetterExecutor; private static final String FAILED = "failed"; + private static MockedStatic<TrainingJobClusterStateListener> trainingJobClusterStateListenerMockedStatic; @BeforeClass public static void setup() { modelGetterExecutor = Executors.newSingleThreadExecutor(); + trainingJobClusterStateListenerMockedStatic = mockStatic(TrainingJobClusterStateListener.class); + final TrainingJobClusterStateListener trainingJobClusterStateListener = mock(TrainingJobClusterStateListener.class); + doNothing().when(trainingJobClusterStateListener).clusterChanged(any(ClusterChangedEvent.class)); + trainingJobClusterStateListenerMockedStatic.when(TrainingJobClusterStateListener::getInstance) + .thenReturn(trainingJobClusterStateListener); } @AfterClass public static void teardown() { modelGetterExecutor.shutdown(); + trainingJobClusterStateListenerMockedStatic.close(); } public void testCreate() throws IOException, InterruptedException {