Skip to content

Commit 8e47abc

Browse files
committed
exclude remote models in circuit breaker checks and fix memory CB bugs
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 478886c commit 8e47abc

11 files changed

+92
-21
lines changed

plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java

+3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import org.opensearch.common.util.concurrent.ThreadContext;
2020
import org.opensearch.commons.authuser.User;
2121
import org.opensearch.core.action.ActionListener;
22+
import org.opensearch.core.common.breaker.CircuitBreakingException;
2223
import org.opensearch.core.rest.RestStatus;
2324
import org.opensearch.core.xcontent.NamedXContentRegistry;
2425
import org.opensearch.ml.common.FunctionName;
@@ -171,6 +172,8 @@ public void onResponse(MLModel mlModel) {
171172
);
172173
} else if (e instanceof MLResourceNotFoundException) {
173174
wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.NOT_FOUND));
175+
} else if (e instanceof CircuitBreakingException) {
176+
wrappedListener.onFailure(e);
174177
} else {
175178
wrappedListener
176179
.onFailure(

plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,6 @@ public Short getThreshold() {
5050

5151
@Override
5252
public boolean isOpen() {
53-
return jvmService.stats().getMem().getHeapUsedPercent() > this.getThreshold();
53+
return getThreshold() < 100 && jvmService.stats().getMem().getHeapUsedPercent() > getThreshold();
5454
}
5555
}

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import java.util.HashMap;
6060
import java.util.List;
6161
import java.util.Map;
62+
import java.util.Objects;
6263
import java.util.Optional;
6364
import java.util.Set;
6465
import java.util.concurrent.ConcurrentLinkedDeque;
@@ -827,7 +828,9 @@ private <T> ThreadedActionListener<T> threadedActionListener(String threadPoolNa
827828
* @param runningTaskLimit limit
828829
*/
829830
public void checkAndAddRunningTask(MLTask mlTask, Integer runningTaskLimit) {
830-
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
831+
if (Objects.nonNull(mlTask) && mlTask.getFunctionName() != FunctionName.REMOTE) {
832+
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
833+
}
831834
mlTaskManager.checkLimitAndAddRunningTask(mlTask, runningTaskLimit);
832835
}
833836

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ public void dispatchTask(
137137
if (clusterService.localNode().getId().equals(node.getId())) {
138138
log.debug("Execute ML predict request {} locally on node {}", request.getRequestID(), node.getId());
139139
request.setDispatchTask(false);
140-
executeTask(request, listener);
140+
checkCBAndExecute(functionName, request, listener);
141141
} else {
142142
log.debug("Execute ML predict request {} remotely on node {}", request.getRequestID(), node.getId());
143143
request.setDispatchTask(false);

plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java

+8-2
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ protected void handleAsyncMLTaskComplete(MLTask mlTask) {
8787
public void run(FunctionName functionName, Request request, TransportService transportService, ActionListener<Response> listener) {
8888
if (!request.isDispatchTask()) {
8989
log.debug("Run ML request {} locally", request.getRequestID());
90-
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
91-
executeTask(request, listener);
90+
checkCBAndExecute(functionName, request, listener);
9291
return;
9392
}
9493
dispatchTask(functionName, request, transportService, listener);
@@ -129,4 +128,11 @@ public void dispatchTask(
129128
protected abstract TransportResponseHandler<Response> getResponseHandler(ActionListener<Response> listener);
130129

131130
protected abstract void executeTask(Request request, ActionListener<Response> listener);
131+
132+
protected void checkCBAndExecute(FunctionName functionName, Request request, ActionListener<Response> listener) {
133+
if (functionName != FunctionName.REMOTE) {
134+
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
135+
}
136+
executeTask(request, listener);
137+
}
132138
}

plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java

+6-2
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
1717
import org.opensearch.common.xcontent.XContentHelper;
1818
import org.opensearch.common.xcontent.XContentType;
19+
import org.opensearch.core.common.breaker.CircuitBreaker;
20+
import org.opensearch.core.common.breaker.CircuitBreakingException;
1921
import org.opensearch.core.common.bytes.BytesReference;
2022
import org.opensearch.core.xcontent.NamedXContentRegistry;
2123
import org.opensearch.core.xcontent.XContentParser;
2224
import org.opensearch.ml.breaker.MLCircuitBreakerService;
2325
import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
24-
import org.opensearch.ml.common.exception.MLLimitExceededException;
2526
import org.opensearch.ml.stats.MLNodeLevelStat;
2627
import org.opensearch.ml.stats.MLStats;
2728

@@ -60,7 +61,10 @@ public static void checkOpenCircuitBreaker(MLCircuitBreakerService mlCircuitBrea
6061
ThresholdCircuitBreaker openCircuitBreaker = mlCircuitBreakerService.checkOpenCB();
6162
if (openCircuitBreaker != null) {
6263
mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).increment();
63-
throw new MLLimitExceededException(openCircuitBreaker.getName() + " is open, please check your resources!");
64+
throw new CircuitBreakingException(
65+
openCircuitBreaker.getName() + " is open, please check your resources!",
66+
CircuitBreaker.Durability.TRANSIENT
67+
);
6468
}
6569
}
6670
}

plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java

+24
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
import org.opensearch.common.util.concurrent.ThreadContext;
3535
import org.opensearch.commons.authuser.User;
3636
import org.opensearch.core.action.ActionListener;
37+
import org.opensearch.core.common.breaker.CircuitBreaker;
38+
import org.opensearch.core.common.breaker.CircuitBreakingException;
3739
import org.opensearch.core.rest.RestStatus;
3840
import org.opensearch.core.xcontent.NamedXContentRegistry;
3941
import org.opensearch.ml.common.FunctionName;
@@ -233,4 +235,26 @@ public void testPrediction_MLResourceNotFoundException() {
233235
assertEquals("Testing MLResourceNotFoundException", argumentCaptor.getValue().getMessage());
234236
}
235237

238+
public void testPrediction_MLLimitExceededException() {
239+
when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model);
240+
when(model.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING);
241+
242+
doAnswer(invocation -> {
243+
ActionListener<Boolean> listener = invocation.getArgument(3);
244+
listener.onFailure(new CircuitBreakingException("Memory Circuit Breaker is open, please check your resources!", CircuitBreaker.Durability.TRANSIENT));
245+
return null;
246+
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any());
247+
248+
doAnswer(invocation -> {
249+
((ActionListener<MLTaskResponse>) invocation.getArguments()[3]).onResponse(null);
250+
return null;
251+
}).when(mlPredictTaskRunner).run(any(), any(), any(), any());
252+
253+
transportPredictionTaskAction.doExecute(null, mlPredictionTaskRequest, actionListener);
254+
255+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(CircuitBreakingException.class);
256+
verify(actionListener).onFailure(argumentCaptor.capture());
257+
assertEquals("Memory Circuit Breaker is open, please check your resources!", argumentCaptor.getValue().getMessage());
258+
}
259+
236260
}

plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java

+18
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,22 @@ public void testIsOpen_UpdatedByClusterSettings_ExceedMemoryThreshold() {
8484
settingsService.applySettings(newSettingsBuilder.build());
8585
Assert.assertFalse(breaker.isOpen());
8686
}
87+
88+
@Test
89+
public void testIsOpen_DisableMemoryCB() {
90+
ClusterSettings settingsService = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
91+
settingsService.registerSetting(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD);
92+
when(clusterService.getClusterSettings()).thenReturn(settingsService);
93+
94+
CircuitBreaker breaker = new MemoryCircuitBreaker(Settings.builder().build(), clusterService, jvmService);
95+
96+
when(mem.getHeapUsedPercent()).thenReturn((short) 90);
97+
Assert.assertTrue(breaker.isOpen());
98+
99+
when(mem.getHeapUsedPercent()).thenReturn((short) 100);
100+
Settings.Builder newSettingsBuilder = Settings.builder();
101+
newSettingsBuilder.put("plugins.ml_commons.jvm_heap_memory_threshold", 100);
102+
settingsService.applySettings(newSettingsBuilder.build());
103+
Assert.assertFalse(breaker.isOpen());
104+
}
87105
}

plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java

+23-10
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@
8080
import org.opensearch.common.settings.Settings;
8181
import org.opensearch.common.util.concurrent.ThreadContext;
8282
import org.opensearch.core.action.ActionListener;
83+
import org.opensearch.core.common.breaker.CircuitBreaker;
84+
import org.opensearch.core.common.breaker.CircuitBreakingException;
8385
import org.opensearch.core.xcontent.NamedXContentRegistry;
8486
import org.opensearch.ml.breaker.MLCircuitBreakerService;
8587
import org.opensearch.ml.breaker.MemoryCircuitBreaker;
@@ -322,7 +324,7 @@ public void testRegisterMLModel_CircuitBreakerOpen() {
322324
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(thresholdCircuitBreaker);
323325
when(thresholdCircuitBreaker.getName()).thenReturn("Disk Circuit Breaker");
324326
when(thresholdCircuitBreaker.getThreshold()).thenReturn(87);
325-
expectedEx.expect(MLException.class);
327+
expectedEx.expect(CircuitBreakingException.class);
326328
expectedEx.expectMessage("Disk Circuit Breaker is open, please check your resources!");
327329
modelManager.registerMLModel(registerModelInput, mlTask);
328330
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
@@ -451,21 +453,32 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException {
451453
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
452454
}
453455

454-
public void testRegisterMLRemoteModel_WhenMemoryCBOpen_ThenFail() {
456+
public void testRegisterMLRemoteModel_SkipMemoryCBOpen() {
455457
ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
456-
MemoryCircuitBreaker memCB = new MemoryCircuitBreaker(mock(JvmService.class));
457-
String memCBIsOpenMessage = memCB.getName() + " is open, please check your resources!";
458-
when(mlCircuitBreakerService.checkOpenCB()).thenThrow(new MLLimitExceededException(memCBIsOpenMessage));
458+
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
459+
when(mlCircuitBreakerService.checkOpenCB())
460+
.thenThrow(
461+
new CircuitBreakingException(
462+
"Memory Circuit Breaker is open, please check your resources!",
463+
CircuitBreaker.Durability.TRANSIENT
464+
)
465+
);
466+
when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService);
467+
when(modelHelper.isModelAllowed(any(), any())).thenReturn(true);
459468

460469
MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true);
461470
MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build();
471+
mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true);
472+
doAnswer(invocation -> {
473+
ActionListener<IndexResponse> indexResponseActionListener = (ActionListener<IndexResponse>) invocation.getArguments()[1];
474+
indexResponseActionListener.onResponse(indexResponse);
475+
return null;
476+
}).when(client).index(any(), any());
477+
when(indexResponse.getId()).thenReturn("mockIndexId");
462478
modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener);
463479

464-
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
465-
verify(listener, times(1)).onFailure(argCaptor.capture());
466-
Exception e = argCaptor.getValue();
467-
assertTrue(e instanceof MLLimitExceededException);
468-
assertEquals(memCBIsOpenMessage, e.getMessage());
480+
assertEquals(pretrainedTask.getFunctionName(), FunctionName.REMOTE);
481+
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
469482
}
470483

471484
public void testIndexRemoteModel() throws PrivilegedActionException {

plugin/src/test/java/org/opensearch/ml/rest/RestMLMemoryCircuitBreakerIT.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public void testRunWithMemoryCircuitBreaker() throws IOException {
5858
exception.getMessage(),
5959
allOf(
6060
containsString("Memory Circuit Breaker is open, please check your resources!"),
61-
containsString("m_l_limit_exceeded_exception")
61+
containsString("circuit_breaking_exception")
6262
)
6363
);
6464

plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -132,15 +132,15 @@ public void testHandleAsyncMLTaskComplete_SyncTask() {
132132
verify(mlTaskManager, never()).updateMLTask(eq(syncMlTask.getTaskId()), any(), anyLong(), anyBoolean());
133133
}
134134

135-
public void testRun_CircuitBreakerOpen() {
135+
public void testRemoteInferenceRun_CircuitBreakerNotOpen() {
136136
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(thresholdCircuitBreaker);
137137
when(thresholdCircuitBreaker.getName()).thenReturn("Memory Circuit Breaker");
138138
when(thresholdCircuitBreaker.getThreshold()).thenReturn(87);
139139
TransportService transportService = mock(TransportService.class);
140140
ActionListener listener = mock(ActionListener.class);
141141
MLTaskRequest request = new MLTaskRequest(false);
142-
expectThrows(MLLimitExceededException.class, () -> mlTaskRunner.run(FunctionName.REMOTE, request, transportService, listener));
142+
mlTaskRunner.run(FunctionName.REMOTE, request, transportService, listener);
143143
Long value = (Long) mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue();
144-
assertEquals(1L, value.longValue());
144+
assertEquals(0L, value.longValue());
145145
}
146146
}

0 commit comments

Comments
 (0)