Skip to content

Commit 96a6404

Browse files
committed
exclude remote models in circuit breaker checks and fix memory CB bugs
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 0c27efc commit 96a6404

File tree

10 files changed

+235
-13
lines changed

10 files changed

+235
-13
lines changed

plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java

+25-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
package org.opensearch.ml.action.prediction;
77

8+
import org.opensearch.OpenSearchStatusException;
89
import org.opensearch.action.ActionRequest;
910
import org.opensearch.action.support.ActionFilters;
1011
import org.opensearch.action.support.HandledTransportAction;
@@ -14,9 +15,12 @@
1415
import org.opensearch.common.util.concurrent.ThreadContext;
1516
import org.opensearch.commons.authuser.User;
1617
import org.opensearch.core.action.ActionListener;
18+
import org.opensearch.core.common.breaker.CircuitBreakingException;
19+
import org.opensearch.core.rest.RestStatus;
1720
import org.opensearch.core.xcontent.NamedXContentRegistry;
1821
import org.opensearch.ml.common.FunctionName;
1922
import org.opensearch.ml.common.MLModel;
23+
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
2024
import org.opensearch.ml.common.exception.MLValidationException;
2125
import org.opensearch.ml.common.transport.MLTaskResponse;
2226
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
@@ -108,7 +112,27 @@ public void onResponse(MLModel mlModel) {
108112
}
109113
}, e -> {
110114
log.error("Failed to Validate Access for ModelId " + modelId, e);
111-
wrappedListener.onFailure(e);
115+
if (e instanceof OpenSearchStatusException) {
116+
wrappedListener
117+
.onFailure(
118+
new OpenSearchStatusException(
119+
e.getMessage(),
120+
RestStatus.fromCode(((OpenSearchStatusException) e).status().getStatus())
121+
)
122+
);
123+
} else if (e instanceof MLResourceNotFoundException) {
124+
wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.NOT_FOUND));
125+
} else if (e instanceof CircuitBreakingException) {
126+
wrappedListener.onFailure(e);
127+
} else {
128+
wrappedListener
129+
.onFailure(
130+
new OpenSearchStatusException(
131+
"Failed to Validate Access for ModelId " + modelId,
132+
RestStatus.FORBIDDEN
133+
)
134+
);
135+
}
112136
}));
113137
}
114138

plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,6 @@ public String getName() {
4545

4646
@Override
4747
public boolean isOpen() {
48-
return jvmService.stats().getMem().getHeapUsedPercent() > this.getThreshold();
48+
return getThreshold() < 100 && jvmService.stats().getMem().getHeapUsedPercent() > getThreshold();
4949
}
5050
}

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import java.util.HashMap;
5454
import java.util.List;
5555
import java.util.Map;
56+
import java.util.Objects;
5657
import java.util.Optional;
5758
import java.util.Set;
5859
import java.util.concurrent.ConcurrentLinkedDeque;
@@ -781,7 +782,9 @@ private <T> ThreadedActionListener<T> threadedActionListener(String threadPoolNa
781782
* @param runningTaskLimit limit
782783
*/
783784
public void checkAndAddRunningTask(MLTask mlTask, Integer runningTaskLimit) {
784-
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
785+
if (Objects.nonNull(mlTask) && mlTask.getFunctionName() != FunctionName.REMOTE) {
786+
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
787+
}
785788
mlTaskManager.checkLimitAndAddRunningTask(mlTask, runningTaskLimit);
786789
}
787790

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ public void dispatchTask(
124124
if (clusterService.localNode().getId().equals(node.getId())) {
125125
log.debug("Execute ML predict request {} locally on node {}", request.getRequestID(), node.getId());
126126
request.setDispatchTask(false);
127-
executeTask(request, listener);
127+
checkCBAndExecute(functionName, request, listener);
128128
} else {
129129
log.debug("Execute ML predict request {} remotely on node {}", request.getRequestID(), node.getId());
130130
request.setDispatchTask(false);

plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java

+8-2
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ protected void handleAsyncMLTaskComplete(MLTask mlTask) {
8787
public void run(FunctionName functionName, Request request, TransportService transportService, ActionListener<Response> listener) {
8888
if (!request.isDispatchTask()) {
8989
log.debug("Run ML request {} locally", request.getRequestID());
90-
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
91-
executeTask(request, listener);
90+
checkCBAndExecute(functionName, request, listener);
9291
return;
9392
}
9493
dispatchTask(functionName, request, transportService, listener);
@@ -129,4 +128,11 @@ public void dispatchTask(
129128
protected abstract TransportResponseHandler<Response> getResponseHandler(ActionListener<Response> listener);
130129

131130
protected abstract void executeTask(Request request, ActionListener<Response> listener);
131+
132+
protected void checkCBAndExecute(FunctionName functionName, Request request, ActionListener<Response> listener) {
133+
if (functionName != FunctionName.REMOTE) {
134+
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
135+
}
136+
executeTask(request, listener);
137+
}
132138
}

plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java

+6-2
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
1717
import org.opensearch.common.xcontent.XContentHelper;
1818
import org.opensearch.common.xcontent.XContentType;
19+
import org.opensearch.core.common.breaker.CircuitBreaker;
20+
import org.opensearch.core.common.breaker.CircuitBreakingException;
1921
import org.opensearch.core.common.bytes.BytesReference;
2022
import org.opensearch.core.xcontent.NamedXContentRegistry;
2123
import org.opensearch.core.xcontent.XContentParser;
2224
import org.opensearch.ml.breaker.MLCircuitBreakerService;
2325
import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
24-
import org.opensearch.ml.common.exception.MLLimitExceededException;
2526
import org.opensearch.ml.stats.MLNodeLevelStat;
2627
import org.opensearch.ml.stats.MLStats;
2728

@@ -60,7 +61,10 @@ public static void checkOpenCircuitBreaker(MLCircuitBreakerService mlCircuitBrea
6061
ThresholdCircuitBreaker openCircuitBreaker = mlCircuitBreakerService.checkOpenCB();
6162
if (openCircuitBreaker != null) {
6263
mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).increment();
63-
throw new MLLimitExceededException(openCircuitBreaker.getName() + " is open, please check your resources!");
64+
throw new CircuitBreakingException(
65+
openCircuitBreaker.getName() + " is open, please check your resources!",
66+
CircuitBreaker.Durability.TRANSIENT
67+
);
6468
}
6569
}
6670
}

plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java

+42
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,16 @@
66
package org.opensearch.ml.breaker;
77

88
import static org.mockito.Mockito.when;
9+
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_JVM_HEAP_MEM_THRESHOLD;
910

1011
import org.junit.Assert;
1112
import org.junit.Before;
1213
import org.junit.Test;
1314
import org.mockito.Mock;
1415
import org.mockito.MockitoAnnotations;
16+
import org.opensearch.cluster.service.ClusterService;
17+
import org.opensearch.common.settings.ClusterSettings;
18+
import org.opensearch.common.settings.Settings;
1519
import org.opensearch.monitor.jvm.JvmService;
1620
import org.opensearch.monitor.jvm.JvmStats;
1721

@@ -26,6 +30,9 @@ public class MemoryCircuitBreakerTests {
2630
@Mock
2731
JvmStats.Mem mem;
2832

33+
@Mock
34+
ClusterService clusterService;
35+
2936
@Before
3037
public void setup() {
3138
MockitoAnnotations.openMocks(this);
@@ -60,4 +67,39 @@ public void testIsOpen_CustomThreshold_ExceedMemoryThreshold() {
6067
when(mem.getHeapUsedPercent()).thenReturn((short) 95);
6168
Assert.assertTrue(breaker.isOpen());
6269
}
70+
71+
@Test
72+
public void testIsOpen_UpdatedByClusterSettings_ExceedMemoryThreshold() {
73+
ClusterSettings settingsService = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
74+
settingsService.registerSetting(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD);
75+
when(clusterService.getClusterSettings()).thenReturn(settingsService);
76+
77+
CircuitBreaker breaker = new MemoryCircuitBreaker(Settings.builder().build(), clusterService, jvmService);
78+
79+
when(mem.getHeapUsedPercent()).thenReturn((short) 90);
80+
Assert.assertTrue(breaker.isOpen());
81+
82+
Settings.Builder newSettingsBuilder = Settings.builder();
83+
newSettingsBuilder.put("plugins.ml_commons.jvm_heap_memory_threshold", 95);
84+
settingsService.applySettings(newSettingsBuilder.build());
85+
Assert.assertFalse(breaker.isOpen());
86+
}
87+
88+
@Test
89+
public void testIsOpen_DisableMemoryCB() {
90+
ClusterSettings settingsService = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
91+
settingsService.registerSetting(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD);
92+
when(clusterService.getClusterSettings()).thenReturn(settingsService);
93+
94+
CircuitBreaker breaker = new MemoryCircuitBreaker(Settings.builder().build(), clusterService, jvmService);
95+
96+
when(mem.getHeapUsedPercent()).thenReturn((short) 90);
97+
Assert.assertTrue(breaker.isOpen());
98+
99+
when(mem.getHeapUsedPercent()).thenReturn((short) 100);
100+
Settings.Builder newSettingsBuilder = Settings.builder();
101+
newSettingsBuilder.put("plugins.ml_commons.jvm_heap_memory_threshold", 100);
102+
settingsService.applySettings(newSettingsBuilder.build());
103+
Assert.assertFalse(breaker.isOpen());
104+
}
63105
}

plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java

+65-1
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@
8080
import org.opensearch.common.settings.Settings;
8181
import org.opensearch.common.util.concurrent.ThreadContext;
8282
import org.opensearch.core.action.ActionListener;
83+
import org.opensearch.core.common.breaker.CircuitBreaker;
84+
import org.opensearch.core.common.breaker.CircuitBreakingException;
8385
import org.opensearch.core.xcontent.NamedXContentRegistry;
8486
import org.opensearch.ml.breaker.MLCircuitBreakerService;
8587
import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
@@ -98,6 +100,7 @@
98100
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
99101
import org.opensearch.ml.common.transport.deploy.MLDeployModelAction;
100102
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
103+
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
101104
import org.opensearch.ml.common.transport.upload_chunk.MLRegisterModelMetaInput;
102105
import org.opensearch.ml.engine.MLEngine;
103106
import org.opensearch.ml.engine.ModelHelper;
@@ -318,7 +321,7 @@ public void testRegisterMLModel_CircuitBreakerOpen() {
318321
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(thresholdCircuitBreaker);
319322
when(thresholdCircuitBreaker.getName()).thenReturn("Disk Circuit Breaker");
320323
when(thresholdCircuitBreaker.getThreshold()).thenReturn(87);
321-
expectedEx.expect(MLException.class);
324+
expectedEx.expect(CircuitBreakingException.class);
322325
expectedEx.expectMessage("Disk Circuit Breaker is open, please check your resources!");
323326
modelManager.registerMLModel(registerModelInput, mlTask);
324327
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
@@ -409,6 +412,55 @@ public void testRegisterMLModel_RegisterPreBuildModel() throws PrivilegedActionE
409412
);
410413
}
411414

415+
public void testRegisterMLRemoteModel() throws PrivilegedActionException {
416+
ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
417+
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
418+
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(null);
419+
when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService);
420+
when(modelHelper.downloadPrebuiltModelMetaList(any(), any())).thenReturn(Collections.singletonList("demo"));
421+
when(modelHelper.isModelAllowed(any(), any())).thenReturn(true);
422+
MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true);
423+
MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build();
424+
mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true);
425+
doAnswer(invocation -> {
426+
ActionListener<IndexResponse> indexResponseActionListener = (ActionListener<IndexResponse>) invocation.getArguments()[1];
427+
indexResponseActionListener.onResponse(indexResponse);
428+
return null;
429+
}).when(client).index(any(), any());
430+
when(indexResponse.getId()).thenReturn("mockIndexId");
431+
modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener);
432+
assertEquals(pretrainedTask.getFunctionName(), FunctionName.REMOTE);
433+
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
434+
}
435+
436+
public void testRegisterMLRemoteModel_SkipMemoryCBOpen() {
437+
ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
438+
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
439+
when(mlCircuitBreakerService.checkOpenCB())
440+
.thenThrow(
441+
new CircuitBreakingException(
442+
"Memory Circuit Breaker is open, please check your resources!",
443+
CircuitBreaker.Durability.TRANSIENT
444+
)
445+
);
446+
when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService);
447+
when(modelHelper.isModelAllowed(any(), any())).thenReturn(true);
448+
449+
MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true);
450+
MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build();
451+
mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true);
452+
doAnswer(invocation -> {
453+
ActionListener<IndexResponse> indexResponseActionListener = (ActionListener<IndexResponse>) invocation.getArguments()[1];
454+
indexResponseActionListener.onResponse(indexResponse);
455+
return null;
456+
}).when(client).index(any(), any());
457+
when(indexResponse.getId()).thenReturn("mockIndexId");
458+
modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener);
459+
460+
assertEquals(pretrainedTask.getFunctionName(), FunctionName.REMOTE);
461+
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
462+
}
463+
412464
@Ignore
413465
public void testRegisterMLModel_DownloadModelFile() throws IOException {
414466
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
@@ -963,4 +1015,16 @@ private MLRegisterModelInput mockPretrainedInput() {
9631015
.functionName(FunctionName.SPARSE_ENCODING)
9641016
.build();
9651017
}
1018+
1019+
private MLRegisterModelInput mockRemoteModelInput(boolean isHidden) {
1020+
return MLRegisterModelInput
1021+
.builder()
1022+
.modelName(modelName)
1023+
.version(version)
1024+
.modelGroupId("modelGroupId")
1025+
.modelFormat(modelFormat)
1026+
.functionName(FunctionName.REMOTE)
1027+
.deployModel(true)
1028+
.build();
1029+
}
9661030
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
package org.opensearch.ml.rest;
6+
7+
import static org.hamcrest.Matchers.allOf;
8+
import static org.hamcrest.Matchers.containsString;
9+
10+
import java.io.IOException;
11+
12+
import org.apache.http.HttpHeaders;
13+
import org.apache.http.message.BasicHeader;
14+
import org.junit.After;
15+
import org.opensearch.client.Response;
16+
import org.opensearch.client.ResponseException;
17+
import org.opensearch.ml.breaker.MemoryCircuitBreaker;
18+
import org.opensearch.ml.utils.TestHelper;
19+
20+
import com.google.common.collect.ImmutableList;
21+
22+
public class RestMLMemoryCircuitBreakerIT extends MLCommonsRestTestCase {
23+
@After
24+
public void tearDown() throws Exception {
25+
super.tearDown();
26+
// restore the threshold to default value
27+
Response response1 = TestHelper
28+
.makeRequest(
29+
client(),
30+
"PUT",
31+
"_cluster/settings",
32+
null,
33+
"{\"persistent\":{\"plugins.ml_commons.jvm_heap_memory_threshold\":"
34+
+ MemoryCircuitBreaker.DEFAULT_JVM_HEAP_USAGE_THRESHOLD
35+
+ "}}",
36+
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
37+
);
38+
assertEquals(200, response1.getStatusLine().getStatusCode());
39+
}
40+
41+
public void testRunWithMemoryCircuitBreaker() throws IOException {
42+
// set a low threshold
43+
Response response1 = TestHelper
44+
.makeRequest(
45+
client(),
46+
"PUT",
47+
"_cluster/settings",
48+
null,
49+
"{\"persistent\":{\"plugins.ml_commons.jvm_heap_memory_threshold\":1}}",
50+
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
51+
);
52+
assertEquals(200, response1.getStatusLine().getStatusCode());
53+
54+
// expect task fail due to memory limit
55+
Exception exception = assertThrows(ResponseException.class, () -> ingestModelData());
56+
org.hamcrest.MatcherAssert
57+
.assertThat(
58+
exception.getMessage(),
59+
allOf(
60+
containsString("Memory Circuit Breaker is open, please check your resources!"),
61+
containsString("circuit_breaking_exception")
62+
)
63+
);
64+
65+
// set a higher threshold
66+
Response response2 = TestHelper
67+
.makeRequest(
68+
client(),
69+
"PUT",
70+
"_cluster/settings",
71+
null,
72+
"{\"persistent\":{\"plugins.ml_commons.jvm_heap_memory_threshold\":100}}",
73+
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
74+
);
75+
assertEquals(200, response2.getStatusLine().getStatusCode());
76+
77+
// expect task success
78+
ingestModelData();
79+
}
80+
}

0 commit comments

Comments
 (0)