From baec05369826d3aca3df3927a62f780b8fea562a Mon Sep 17 00:00:00 2001
From: Xun Zhang <xunzh@amazon.com>
Date: Thu, 23 May 2024 14:46:08 -0700
Subject: [PATCH] fix memory CB bugs and upgrade UTs to compatible with core
 changes (#2469)

* fix memory CB bugs

Signed-off-by: Xun Zhang <xunzh@amazon.com>

* change CB limit exception code to 429 and skip CB check for remote models

Signed-off-by: Xun Zhang <xunzh@amazon.com>

---------

Signed-off-by: Xun Zhang <xunzh@amazon.com>
(cherry picked from commit f88b6d60730afb71f3dce6d3fb65d5f5b085e7bb)
---
 .../TransportPredictionTaskAction.java        |  3 ++
 .../ml/breaker/MemoryCircuitBreaker.java      |  2 +-
 .../opensearch/ml/model/MLModelManager.java   |  5 ++-
 .../ml/task/MLPredictTaskRunner.java          |  2 +-
 .../org/opensearch/ml/task/MLTaskRunner.java  | 10 ++++-
 .../org/opensearch/ml/utils/MLNodeUtils.java  |  8 +++-
 .../TransportPredictionTaskActionTests.java   | 24 +++++++++++
 .../ml/breaker/MemoryCircuitBreakerTests.java | 18 ++++++++
 .../ml/model/MLModelCacheHelperTests.java     |  6 ++-
 .../ml/model/MLModelManagerTests.java         | 42 ++++++++++++-------
 .../ml/rest/RestMLMemoryCircuitBreakerIT.java |  2 +-
 .../ml/rest/RestMLRemoteInferenceIT.java      |  6 ++-
 .../ml/task/MLExecuteTaskRunnerTests.java     |  9 ++--
 .../opensearch/ml/task/TaskRunnerTests.java   |  5 +--
 14 files changed, 107 insertions(+), 35 deletions(-)

diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java
index 94ed36214a..4cf957c499 100644
--- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java
+++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java
@@ -20,6 +20,7 @@
 import org.opensearch.common.xcontent.XContentFactory;
 import org.opensearch.commons.authuser.User;
 import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.common.breaker.CircuitBreakingException;
 import org.opensearch.core.rest.RestStatus;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.core.xcontent.ToXContent;
@@ -177,6 +178,8 @@ public void onResponse(MLModel mlModel) {
                                     );
                             } else if (e instanceof MLResourceNotFoundException) {
                                 wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.NOT_FOUND));
+                            } else if (e instanceof CircuitBreakingException) {
+                                wrappedListener.onFailure(e);
                             } else {
                                 wrappedListener
                                     .onFailure(
diff --git a/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java b/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java
index 5e045ae539..c1287ef481 100644
--- a/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java
+++ b/plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java
@@ -50,6 +50,6 @@ public Short getThreshold() {
 
     @Override
     public boolean isOpen() {
-        return jvmService.stats().getMem().getHeapUsedPercent() > this.getThreshold();
+        return getThreshold() < 100 && jvmService.stats().getMem().getHeapUsedPercent() > getThreshold();
     }
 }
diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
index 30cc0a0567..fd415828e6 100644
--- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
+++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java
@@ -59,6 +59,7 @@
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.ConcurrentLinkedDeque;
@@ -834,7 +835,9 @@ private <T> ThreadedActionListener<T> threadedActionListener(String threadPoolNa
      * @param runningTaskLimit limit
      */
     public void checkAndAddRunningTask(MLTask mlTask, Integer runningTaskLimit) {
-        checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
+        if (Objects.nonNull(mlTask) && mlTask.getFunctionName() != FunctionName.REMOTE) {
+            checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
+        }
         mlTaskManager.checkLimitAndAddRunningTask(mlTask, runningTaskLimit);
     }
 
diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java
index 101d9c9244..b341f4c9f5 100644
--- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java
+++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java
@@ -143,7 +143,7 @@ public void dispatchTask(
                 if (clusterService.localNode().getId().equals(node.getId())) {
                     log.debug("Execute ML predict request {} locally on node {}", request.getRequestID(), node.getId());
                     request.setDispatchTask(false);
-                    executeTask(request, listener);
+                    checkCBAndExecute(functionName, request, listener);
                 } else {
                     log.debug("Execute ML predict request {} remotely on node {}", request.getRequestID(), node.getId());
                     request.setDispatchTask(false);
diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java
index b2c71d6ed8..54195ab156 100644
--- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java
+++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java
@@ -87,8 +87,7 @@ protected void handleAsyncMLTaskComplete(MLTask mlTask) {
     public void run(FunctionName functionName, Request request, TransportService transportService, ActionListener<Response> listener) {
         if (!request.isDispatchTask()) {
             log.debug("Run ML request {} locally", request.getRequestID());
-            checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
-            executeTask(request, listener);
+            checkCBAndExecute(functionName, request, listener);
             return;
         }
         dispatchTask(functionName, request, transportService, listener);
@@ -129,4 +128,11 @@ public void dispatchTask(
     protected abstract TransportResponseHandler<Response> getResponseHandler(ActionListener<Response> listener);
 
     protected abstract void executeTask(Request request, ActionListener<Response> listener);
+
+    protected void checkCBAndExecute(FunctionName functionName, Request request, ActionListener<Response> listener) {
+        if (functionName != FunctionName.REMOTE) {
+            checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
+        }
+        executeTask(request, listener);
+    }
 }
diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java
index 227518aabf..86fbfb1605 100644
--- a/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java
+++ b/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java
@@ -18,12 +18,13 @@
 import org.opensearch.common.xcontent.LoggingDeprecationHandler;
 import org.opensearch.common.xcontent.XContentHelper;
 import org.opensearch.common.xcontent.XContentType;
+import org.opensearch.core.common.breaker.CircuitBreaker;
+import org.opensearch.core.common.breaker.CircuitBreakingException;
 import org.opensearch.core.common.bytes.BytesReference;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.core.xcontent.XContentParser;
 import org.opensearch.ml.breaker.MLCircuitBreakerService;
 import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
-import org.opensearch.ml.common.exception.MLLimitExceededException;
 import org.opensearch.ml.stats.MLNodeLevelStat;
 import org.opensearch.ml.stats.MLStats;
 
@@ -92,7 +93,10 @@ public static void checkOpenCircuitBreaker(MLCircuitBreakerService mlCircuitBrea
         ThresholdCircuitBreaker openCircuitBreaker = mlCircuitBreakerService.checkOpenCB();
         if (openCircuitBreaker != null) {
             mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).increment();
-            throw new MLLimitExceededException(openCircuitBreaker.getName() + " is open, please check your resources!");
+            throw new CircuitBreakingException(
+                openCircuitBreaker.getName() + " is open, please check your resources!",
+                CircuitBreaker.Durability.TRANSIENT
+            );
         }
     }
 }
diff --git a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java
index aa7afdce6e..baaf2cec05 100644
--- a/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java
@@ -35,6 +35,8 @@
 import org.opensearch.common.util.concurrent.ThreadContext;
 import org.opensearch.commons.authuser.User;
 import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.common.breaker.CircuitBreaker;
+import org.opensearch.core.common.breaker.CircuitBreakingException;
 import org.opensearch.core.rest.RestStatus;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.ml.common.FunctionName;
@@ -235,6 +237,28 @@ public void testPrediction_MLResourceNotFoundException() {
         assertEquals("Testing MLResourceNotFoundException", argumentCaptor.getValue().getMessage());
     }
 
+    public void testPrediction_MLLimitExceededException() {
+        when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model);
+        when(model.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING);
+
+        doAnswer(invocation -> {
+            ActionListener<Boolean> listener = invocation.getArgument(3);
+            listener.onFailure(new CircuitBreakingException("Memory Circuit Breaker is open, please check your resources!", CircuitBreaker.Durability.TRANSIENT));
+            return null;
+        }).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any());
+
+        doAnswer(invocation -> {
+            ((ActionListener<MLTaskResponse>) invocation.getArguments()[3]).onResponse(null);
+            return null;
+        }).when(mlPredictTaskRunner).run(any(), any(), any(), any());
+
+        transportPredictionTaskAction.doExecute(null, mlPredictionTaskRequest, actionListener);
+
+        ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(CircuitBreakingException.class);
+        verify(actionListener).onFailure(argumentCaptor.capture());
+        assertEquals("Memory Circuit Breaker is open, please check your resources!", argumentCaptor.getValue().getMessage());
+    }
+
     public void testValidateInputSchemaSuccess() {
         RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet
             .builder()
diff --git a/plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java b/plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java
index cdd1f6fc22..8c7f6f41d4 100644
--- a/plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java
@@ -84,4 +84,22 @@ public void testIsOpen_UpdatedByClusterSettings_ExceedMemoryThreshold() {
         settingsService.applySettings(newSettingsBuilder.build());
         Assert.assertFalse(breaker.isOpen());
     }
+
+    @Test
+    public void testIsOpen_DisableMemoryCB() {
+        ClusterSettings settingsService = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
+        settingsService.registerSetting(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD);
+        when(clusterService.getClusterSettings()).thenReturn(settingsService);
+
+        CircuitBreaker breaker = new MemoryCircuitBreaker(Settings.builder().build(), clusterService, jvmService);
+
+        when(mem.getHeapUsedPercent()).thenReturn((short) 90);
+        Assert.assertTrue(breaker.isOpen());
+
+        when(mem.getHeapUsedPercent()).thenReturn((short) 100);
+        Settings.Builder newSettingsBuilder = Settings.builder();
+        newSettingsBuilder.put("plugins.ml_commons.jvm_heap_memory_threshold", 100);
+        settingsService.applySettings(newSettingsBuilder.build());
+        Assert.assertFalse(breaker.isOpen());
+    }
 }
diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java
index 232290520d..4b4e6ace27 100644
--- a/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java
@@ -24,6 +24,7 @@
 import org.junit.rules.ExpectedException;
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
+import org.opensearch.cluster.service.ClusterApplierService;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.settings.ClusterSettings;
 import org.opensearch.common.settings.Settings;
@@ -64,13 +65,16 @@ public class MLModelCacheHelperTests extends OpenSearchTestCase {
     @Mock
     private TokenBucket rateLimiter;
 
+    @Mock
+    ClusterApplierService clusterApplierService;
+
     @Before
     public void setup() {
         MockitoAnnotations.openMocks(this);
         maxMonitoringRequests = 10;
         settings = Settings.builder().put(ML_COMMONS_MONITORING_REQUEST_COUNT.getKey(), maxMonitoringRequests).build();
         ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_MONITORING_REQUEST_COUNT);
-        clusterService = spy(new ClusterService(settings, clusterSettings, null));
+        clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterApplierService));
 
         when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
         cacheHelper = new MLModelCacheHelper(clusterService, settings);
diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java
index d42fa9ca65..189ac01876 100644
--- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java
@@ -76,14 +76,16 @@
 import org.opensearch.action.update.UpdateRequest;
 import org.opensearch.action.update.UpdateResponse;
 import org.opensearch.client.Client;
+import org.opensearch.cluster.service.ClusterApplierService;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.settings.ClusterSettings;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.common.util.concurrent.ThreadContext;
 import org.opensearch.core.action.ActionListener;
+import org.opensearch.core.common.breaker.CircuitBreaker;
+import org.opensearch.core.common.breaker.CircuitBreakingException;
 import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.ml.breaker.MLCircuitBreakerService;
-import org.opensearch.ml.breaker.MemoryCircuitBreaker;
 import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
 import org.opensearch.ml.cluster.DiscoveryNodeHelper;
 import org.opensearch.ml.common.FunctionName;
@@ -114,7 +116,6 @@
 import org.opensearch.ml.stats.MLStats;
 import org.opensearch.ml.stats.suppliers.CounterSupplier;
 import org.opensearch.ml.task.MLTaskManager;
-import org.opensearch.monitor.jvm.JvmService;
 import org.opensearch.script.ScriptService;
 import org.opensearch.test.OpenSearchTestCase;
 import org.opensearch.threadpool.ThreadPool;
@@ -177,7 +178,7 @@ public class MLModelManagerTests extends OpenSearchTestCase {
     private ScriptService scriptService;
 
     @Mock
-    private MLTask pretrainedMLTask;
+    ClusterApplierService clusterApplierService;
 
     @Before
     public void setup() throws URISyntaxException {
@@ -196,7 +197,7 @@ public void setup() throws URISyntaxException {
             ML_COMMONS_MONITORING_REQUEST_COUNT,
             ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE
         );
-        clusterService = spy(new ClusterService(settings, clusterSettings, null));
+        clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterApplierService));
         xContentRegistry = NamedXContentRegistry.EMPTY;
 
         modelName = "model_name1";
@@ -323,7 +324,7 @@ public void testRegisterMLModel_CircuitBreakerOpen() {
         when(mlCircuitBreakerService.checkOpenCB()).thenReturn(thresholdCircuitBreaker);
         when(thresholdCircuitBreaker.getName()).thenReturn("Disk Circuit Breaker");
         when(thresholdCircuitBreaker.getThreshold()).thenReturn(87);
-        expectedEx.expect(MLException.class);
+        expectedEx.expect(CircuitBreakingException.class);
         expectedEx.expectMessage("Disk Circuit Breaker is open, please check your resources!");
         modelManager.registerMLModel(registerModelInput, mlTask);
         verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
@@ -452,21 +453,30 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException {
         verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
     }
 
-    public void testRegisterMLRemoteModel_WhenMemoryCBOpen_ThenFail() {
+    public void testRegisterMLRemoteModel_SkipMemoryCBOpen() {
         ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
-        MemoryCircuitBreaker memCB = new MemoryCircuitBreaker(mock(JvmService.class));
-        String memCBIsOpenMessage = memCB.getName() + " is open, please check your resources!";
-        when(mlCircuitBreakerService.checkOpenCB()).thenThrow(new MLLimitExceededException(memCBIsOpenMessage));
-
+        doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
+        when(mlCircuitBreakerService.checkOpenCB())
+            .thenThrow(
+                new CircuitBreakingException(
+                    "Memory Circuit Breaker is open, please check your resources!",
+                    CircuitBreaker.Durability.TRANSIENT
+                )
+            );
+        when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService);
+        when(modelHelper.isModelAllowed(any(), any())).thenReturn(true);
         MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true);
         MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build();
+        mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true);
+        doAnswer(invocation -> {
+            ActionListener<IndexResponse> indexResponseActionListener = (ActionListener<IndexResponse>) invocation.getArguments()[1];
+            indexResponseActionListener.onResponse(indexResponse);
+            return null;
+        }).when(client).index(any(), any());
+        when(indexResponse.getId()).thenReturn("mockIndexId");
         modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener);
-
-        ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
-        verify(listener, times(1)).onFailure(argCaptor.capture());
-        Exception e = argCaptor.getValue();
-        assertTrue(e instanceof MLLimitExceededException);
-        assertEquals(memCBIsOpenMessage, e.getMessage());
+        assertEquals(pretrainedTask.getFunctionName(), FunctionName.REMOTE);
+        verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
     }
 
     public void testIndexRemoteModel() throws PrivilegedActionException {
diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLMemoryCircuitBreakerIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLMemoryCircuitBreakerIT.java
index d1d332050e..dcaa2610b7 100644
--- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLMemoryCircuitBreakerIT.java
+++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLMemoryCircuitBreakerIT.java
@@ -58,7 +58,7 @@ public void testRunWithMemoryCircuitBreaker() throws IOException {
                 exception.getMessage(),
                 allOf(
                     containsString("Memory Circuit Breaker is open, please check your resources!"),
-                    containsString("m_l_limit_exceeded_exception")
+                    containsString("circuit_breaking_exception")
                 )
             );
 
diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java
index 4da3596e03..3e4db6d255 100644
--- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java
+++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java
@@ -199,7 +199,7 @@ public void testPredictWithAutoDeployAndTTL_RemoteModel() throws IOException, In
         Response response = createConnector(completionModelConnectorEntity);
         Map responseMap = parseResponseToMap(response);
         String connectorId = (String) responseMap.get("connector_id");
-        response = registerRemoteModelWithTTL("openAI-GPT-3.5 completions", connectorId, 1);
+        response = registerRemoteModelWithTTLAndSkipHeapMemCheck("openAI-GPT-3.5 completions", connectorId, 1);
         responseMap = parseResponseToMap(response);
         String modelId = (String) responseMap.get("model_id");
         String predictInput = "{\n" + "  \"parameters\": {\n" + "      \"prompt\": \"Say this is a test\"\n" + "  }\n" + "}";
@@ -814,11 +814,13 @@ public static Response registerRemoteModel(String name, String connectorId) thro
             .makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
     }
 
-    public static Response registerRemoteModelWithTTL(String name, String connectorId, int ttl) throws IOException {
+    public static Response registerRemoteModelWithTTLAndSkipHeapMemCheck(String name, String connectorId, int ttl) throws IOException {
         String registerModelGroupEntity = "{\n"
             + "  \"name\": \"remote_model_group\",\n"
             + "  \"description\": \"This is an example description\"\n"
             + "}";
+        String updateJVMHeapThreshold = "{\"persistent\":{\"plugins.ml_commons.jvm_heap_memory_threshold\":0}}";
+        TestHelper.makeRequest(client(), "PUT", "/_cluster/settings", null, TestHelper.toHttpEntity(updateJVMHeapThreshold), null);
         Response response = TestHelper
             .makeRequest(
                 client(),
diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java
index 9011746797..a4e7a87a82 100644
--- a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java
@@ -28,6 +28,7 @@
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 import org.opensearch.client.Client;
+import org.opensearch.cluster.service.ClusterApplierService;
 import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.settings.ClusterSettings;
 import org.opensearch.common.settings.Settings;
@@ -48,7 +49,6 @@
 import org.opensearch.ml.stats.suppliers.CounterSupplier;
 import org.opensearch.test.OpenSearchTestCase;
 import org.opensearch.threadpool.ThreadPool;
-import org.opensearch.transport.TransportService;
 
 public class MLExecuteTaskRunnerTests extends OpenSearchTestCase {
 
@@ -70,13 +70,12 @@ public class MLExecuteTaskRunnerTests extends OpenSearchTestCase {
     @Mock
     MLCircuitBreakerService mlCircuitBreakerService;
 
-    @Mock
-    TransportService transportService;
-
     @Mock
     ActionListener<MLExecuteTaskResponse> listener;
     @Mock
     DiscoveryNodeHelper nodeHelper;
+    @Mock
+    ClusterApplierService clusterApplierService;
 
     @Rule
     public ExpectedException exceptionRule = ExpectedException.none();
@@ -115,7 +114,7 @@ public void setup() {
             ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE,
             ML_COMMONS_ENABLE_INHOUSE_PYTHON_MODEL
         );
-        clusterService = spy(new ClusterService(settings, clusterSettings, null));
+        clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterApplierService));
         when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
 
         Map<Enum, MLStat<?>> stats = new ConcurrentHashMap<>();
diff --git a/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java
index 9e2abccebb..84d5e5af0e 100644
--- a/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java
@@ -34,7 +34,6 @@
 import org.opensearch.ml.common.MLTask;
 import org.opensearch.ml.common.MLTaskState;
 import org.opensearch.ml.common.MLTaskType;
-import org.opensearch.ml.common.exception.MLLimitExceededException;
 import org.opensearch.ml.common.transport.MLTaskRequest;
 import org.opensearch.ml.stats.MLNodeLevelStat;
 import org.opensearch.ml.stats.MLStat;
@@ -139,8 +138,8 @@ public void testRun_CircuitBreakerOpen() {
         TransportService transportService = mock(TransportService.class);
         ActionListener listener = mock(ActionListener.class);
         MLTaskRequest request = new MLTaskRequest(false);
-        expectThrows(MLLimitExceededException.class, () -> mlTaskRunner.run(FunctionName.REMOTE, request, transportService, listener));
+        mlTaskRunner.run(FunctionName.REMOTE, request, transportService, listener);
         Long value = (Long) mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue();
-        assertEquals(1L, value.longValue());
+        assertEquals(0L, value.longValue());
     }
 }