Skip to content

Commit f88b6d6

Browse files
authored
fix memory CB bugs and upgrade UTs to compatible with core changes (#2469)
* fix memory CB bugs Signed-off-by: Xun Zhang <xunzh@amazon.com> * change CB limit exception code to 429 and skip CB check for remote models Signed-off-by: Xun Zhang <xunzh@amazon.com> --------- Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 99e75aa commit f88b6d6

14 files changed

+107
-35
lines changed

plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java

+3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import org.opensearch.common.xcontent.XContentFactory;
2121
import org.opensearch.commons.authuser.User;
2222
import org.opensearch.core.action.ActionListener;
23+
import org.opensearch.core.common.breaker.CircuitBreakingException;
2324
import org.opensearch.core.rest.RestStatus;
2425
import org.opensearch.core.xcontent.NamedXContentRegistry;
2526
import org.opensearch.core.xcontent.ToXContent;
@@ -177,6 +178,8 @@ public void onResponse(MLModel mlModel) {
177178
);
178179
} else if (e instanceof MLResourceNotFoundException) {
179180
wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.NOT_FOUND));
181+
} else if (e instanceof CircuitBreakingException) {
182+
wrappedListener.onFailure(e);
180183
} else {
181184
wrappedListener
182185
.onFailure(

plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,6 @@ public Short getThreshold() {
5050

5151
@Override
5252
public boolean isOpen() {
53-
return jvmService.stats().getMem().getHeapUsedPercent() > this.getThreshold();
53+
return getThreshold() < 100 && jvmService.stats().getMem().getHeapUsedPercent() > getThreshold();
5454
}
5555
}

plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java

+4-1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959
import java.util.HashMap;
6060
import java.util.List;
6161
import java.util.Map;
62+
import java.util.Objects;
6263
import java.util.Optional;
6364
import java.util.Set;
6465
import java.util.concurrent.ConcurrentLinkedDeque;
@@ -838,7 +839,9 @@ private <T> ThreadedActionListener<T> threadedActionListener(String threadPoolNa
838839
* @param runningTaskLimit limit
839840
*/
840841
public void checkAndAddRunningTask(MLTask mlTask, Integer runningTaskLimit) {
841-
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
842+
if (Objects.nonNull(mlTask) && mlTask.getFunctionName() != FunctionName.REMOTE) {
843+
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
844+
}
842845
mlTaskManager.checkLimitAndAddRunningTask(mlTask, runningTaskLimit);
843846
}
844847

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ public void dispatchTask(
143143
if (clusterService.localNode().getId().equals(node.getId())) {
144144
log.debug("Execute ML predict request {} locally on node {}", request.getRequestID(), node.getId());
145145
request.setDispatchTask(false);
146-
executeTask(request, listener);
146+
checkCBAndExecute(functionName, request, listener);
147147
} else {
148148
log.debug("Execute ML predict request {} remotely on node {}", request.getRequestID(), node.getId());
149149
request.setDispatchTask(false);

plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java

+8-2
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,7 @@ protected void handleAsyncMLTaskComplete(MLTask mlTask) {
8787
public void run(FunctionName functionName, Request request, TransportService transportService, ActionListener<Response> listener) {
8888
if (!request.isDispatchTask()) {
8989
log.debug("Run ML request {} locally", request.getRequestID());
90-
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
91-
executeTask(request, listener);
90+
checkCBAndExecute(functionName, request, listener);
9291
return;
9392
}
9493
dispatchTask(functionName, request, transportService, listener);
@@ -129,4 +128,11 @@ public void dispatchTask(
129128
protected abstract TransportResponseHandler<Response> getResponseHandler(ActionListener<Response> listener);
130129

131130
protected abstract void executeTask(Request request, ActionListener<Response> listener);
131+
132+
protected void checkCBAndExecute(FunctionName functionName, Request request, ActionListener<Response> listener) {
133+
if (functionName != FunctionName.REMOTE) {
134+
checkOpenCircuitBreaker(mlCircuitBreakerService, mlStats);
135+
}
136+
executeTask(request, listener);
137+
}
132138
}

plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java

+6-2
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
1919
import org.opensearch.common.xcontent.XContentHelper;
2020
import org.opensearch.common.xcontent.XContentType;
21+
import org.opensearch.core.common.breaker.CircuitBreaker;
22+
import org.opensearch.core.common.breaker.CircuitBreakingException;
2123
import org.opensearch.core.common.bytes.BytesReference;
2224
import org.opensearch.core.xcontent.NamedXContentRegistry;
2325
import org.opensearch.core.xcontent.XContentParser;
2426
import org.opensearch.ml.breaker.MLCircuitBreakerService;
2527
import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
26-
import org.opensearch.ml.common.exception.MLLimitExceededException;
2728
import org.opensearch.ml.stats.MLNodeLevelStat;
2829
import org.opensearch.ml.stats.MLStats;
2930

@@ -92,7 +93,10 @@ public static void checkOpenCircuitBreaker(MLCircuitBreakerService mlCircuitBrea
9293
ThresholdCircuitBreaker openCircuitBreaker = mlCircuitBreakerService.checkOpenCB();
9394
if (openCircuitBreaker != null) {
9495
mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).increment();
95-
throw new MLLimitExceededException(openCircuitBreaker.getName() + " is open, please check your resources!");
96+
throw new CircuitBreakingException(
97+
openCircuitBreaker.getName() + " is open, please check your resources!",
98+
CircuitBreaker.Durability.TRANSIENT
99+
);
96100
}
97101
}
98102
}

plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java

+24
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
import org.opensearch.common.util.concurrent.ThreadContext;
3636
import org.opensearch.commons.authuser.User;
3737
import org.opensearch.core.action.ActionListener;
38+
import org.opensearch.core.common.breaker.CircuitBreaker;
39+
import org.opensearch.core.common.breaker.CircuitBreakingException;
3840
import org.opensearch.core.rest.RestStatus;
3941
import org.opensearch.core.xcontent.NamedXContentRegistry;
4042
import org.opensearch.ml.common.FunctionName;
@@ -235,6 +237,28 @@ public void testPrediction_MLResourceNotFoundException() {
235237
assertEquals("Testing MLResourceNotFoundException", argumentCaptor.getValue().getMessage());
236238
}
237239

240+
public void testPrediction_MLLimitExceededException() {
241+
when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model);
242+
when(model.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING);
243+
244+
doAnswer(invocation -> {
245+
ActionListener<Boolean> listener = invocation.getArgument(3);
246+
listener.onFailure(new CircuitBreakingException("Memory Circuit Breaker is open, please check your resources!", CircuitBreaker.Durability.TRANSIENT));
247+
return null;
248+
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any());
249+
250+
doAnswer(invocation -> {
251+
((ActionListener<MLTaskResponse>) invocation.getArguments()[3]).onResponse(null);
252+
return null;
253+
}).when(mlPredictTaskRunner).run(any(), any(), any(), any());
254+
255+
transportPredictionTaskAction.doExecute(null, mlPredictionTaskRequest, actionListener);
256+
257+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(CircuitBreakingException.class);
258+
verify(actionListener).onFailure(argumentCaptor.capture());
259+
assertEquals("Memory Circuit Breaker is open, please check your resources!", argumentCaptor.getValue().getMessage());
260+
}
261+
238262
public void testValidateInputSchemaSuccess() {
239263
RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet
240264
.builder()

plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java

+18
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,22 @@ public void testIsOpen_UpdatedByClusterSettings_ExceedMemoryThreshold() {
8484
settingsService.applySettings(newSettingsBuilder.build());
8585
Assert.assertFalse(breaker.isOpen());
8686
}
87+
88+
@Test
89+
public void testIsOpen_DisableMemoryCB() {
90+
ClusterSettings settingsService = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
91+
settingsService.registerSetting(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD);
92+
when(clusterService.getClusterSettings()).thenReturn(settingsService);
93+
94+
CircuitBreaker breaker = new MemoryCircuitBreaker(Settings.builder().build(), clusterService, jvmService);
95+
96+
when(mem.getHeapUsedPercent()).thenReturn((short) 90);
97+
Assert.assertTrue(breaker.isOpen());
98+
99+
when(mem.getHeapUsedPercent()).thenReturn((short) 100);
100+
Settings.Builder newSettingsBuilder = Settings.builder();
101+
newSettingsBuilder.put("plugins.ml_commons.jvm_heap_memory_threshold", 100);
102+
settingsService.applySettings(newSettingsBuilder.build());
103+
Assert.assertFalse(breaker.isOpen());
104+
}
87105
}

plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.junit.rules.ExpectedException;
2525
import org.mockito.Mock;
2626
import org.mockito.MockitoAnnotations;
27+
import org.opensearch.cluster.service.ClusterApplierService;
2728
import org.opensearch.cluster.service.ClusterService;
2829
import org.opensearch.common.settings.ClusterSettings;
2930
import org.opensearch.common.settings.Settings;
@@ -64,13 +65,16 @@ public class MLModelCacheHelperTests extends OpenSearchTestCase {
6465
@Mock
6566
private TokenBucket rateLimiter;
6667

68+
@Mock
69+
ClusterApplierService clusterApplierService;
70+
6771
@Before
6872
public void setup() {
6973
MockitoAnnotations.openMocks(this);
7074
maxMonitoringRequests = 10;
7175
settings = Settings.builder().put(ML_COMMONS_MONITORING_REQUEST_COUNT.getKey(), maxMonitoringRequests).build();
7276
ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_MONITORING_REQUEST_COUNT);
73-
clusterService = spy(new ClusterService(settings, clusterSettings, null));
77+
clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterApplierService));
7478

7579
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
7680
cacheHelper = new MLModelCacheHelper(clusterService, settings);

plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java

+26-16
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,16 @@
7676
import org.opensearch.action.update.UpdateRequest;
7777
import org.opensearch.action.update.UpdateResponse;
7878
import org.opensearch.client.Client;
79+
import org.opensearch.cluster.service.ClusterApplierService;
7980
import org.opensearch.cluster.service.ClusterService;
8081
import org.opensearch.common.settings.ClusterSettings;
8182
import org.opensearch.common.settings.Settings;
8283
import org.opensearch.common.util.concurrent.ThreadContext;
8384
import org.opensearch.core.action.ActionListener;
85+
import org.opensearch.core.common.breaker.CircuitBreaker;
86+
import org.opensearch.core.common.breaker.CircuitBreakingException;
8487
import org.opensearch.core.xcontent.NamedXContentRegistry;
8588
import org.opensearch.ml.breaker.MLCircuitBreakerService;
86-
import org.opensearch.ml.breaker.MemoryCircuitBreaker;
8789
import org.opensearch.ml.breaker.ThresholdCircuitBreaker;
8890
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
8991
import org.opensearch.ml.common.FunctionName;
@@ -114,7 +116,6 @@
114116
import org.opensearch.ml.stats.MLStats;
115117
import org.opensearch.ml.stats.suppliers.CounterSupplier;
116118
import org.opensearch.ml.task.MLTaskManager;
117-
import org.opensearch.monitor.jvm.JvmService;
118119
import org.opensearch.script.ScriptService;
119120
import org.opensearch.test.OpenSearchTestCase;
120121
import org.opensearch.threadpool.ThreadPool;
@@ -177,7 +178,7 @@ public class MLModelManagerTests extends OpenSearchTestCase {
177178
private ScriptService scriptService;
178179

179180
@Mock
180-
private MLTask pretrainedMLTask;
181+
ClusterApplierService clusterApplierService;
181182

182183
@Before
183184
public void setup() throws URISyntaxException {
@@ -196,7 +197,7 @@ public void setup() throws URISyntaxException {
196197
ML_COMMONS_MONITORING_REQUEST_COUNT,
197198
ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE
198199
);
199-
clusterService = spy(new ClusterService(settings, clusterSettings, null));
200+
clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterApplierService));
200201
xContentRegistry = NamedXContentRegistry.EMPTY;
201202

202203
modelName = "model_name1";
@@ -323,7 +324,7 @@ public void testRegisterMLModel_CircuitBreakerOpen() {
323324
when(mlCircuitBreakerService.checkOpenCB()).thenReturn(thresholdCircuitBreaker);
324325
when(thresholdCircuitBreaker.getName()).thenReturn("Disk Circuit Breaker");
325326
when(thresholdCircuitBreaker.getThreshold()).thenReturn(87);
326-
expectedEx.expect(MLException.class);
327+
expectedEx.expect(CircuitBreakingException.class);
327328
expectedEx.expectMessage("Disk Circuit Breaker is open, please check your resources!");
328329
modelManager.registerMLModel(registerModelInput, mlTask);
329330
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
@@ -452,21 +453,30 @@ public void testRegisterMLRemoteModel() throws PrivilegedActionException {
452453
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
453454
}
454455

455-
public void testRegisterMLRemoteModel_WhenMemoryCBOpen_ThenFail() {
456+
public void testRegisterMLRemoteModel_SkipMemoryCBOpen() {
456457
ActionListener<MLRegisterModelResponse> listener = mock(ActionListener.class);
457-
MemoryCircuitBreaker memCB = new MemoryCircuitBreaker(mock(JvmService.class));
458-
String memCBIsOpenMessage = memCB.getName() + " is open, please check your resources!";
459-
when(mlCircuitBreakerService.checkOpenCB()).thenThrow(new MLLimitExceededException(memCBIsOpenMessage));
460-
458+
doNothing().when(mlTaskManager).checkLimitAndAddRunningTask(any(), any());
459+
when(mlCircuitBreakerService.checkOpenCB())
460+
.thenThrow(
461+
new CircuitBreakingException(
462+
"Memory Circuit Breaker is open, please check your resources!",
463+
CircuitBreaker.Durability.TRANSIENT
464+
)
465+
);
466+
when(threadPool.executor(REGISTER_THREAD_POOL)).thenReturn(taskExecutorService);
467+
when(modelHelper.isModelAllowed(any(), any())).thenReturn(true);
461468
MLRegisterModelInput pretrainedInput = mockRemoteModelInput(true);
462469
MLTask pretrainedTask = MLTask.builder().taskId("pretrained").modelId("pretrained").functionName(FunctionName.REMOTE).build();
470+
mock_MLIndicesHandler_initModelIndex(mlIndicesHandler, true);
471+
doAnswer(invocation -> {
472+
ActionListener<IndexResponse> indexResponseActionListener = (ActionListener<IndexResponse>) invocation.getArguments()[1];
473+
indexResponseActionListener.onResponse(indexResponse);
474+
return null;
475+
}).when(client).index(any(), any());
476+
when(indexResponse.getId()).thenReturn("mockIndexId");
463477
modelManager.registerMLRemoteModel(pretrainedInput, pretrainedTask, listener);
464-
465-
ArgumentCaptor<Exception> argCaptor = ArgumentCaptor.forClass(Exception.class);
466-
verify(listener, times(1)).onFailure(argCaptor.capture());
467-
Exception e = argCaptor.getValue();
468-
assertTrue(e instanceof MLLimitExceededException);
469-
assertEquals(memCBIsOpenMessage, e.getMessage());
478+
assertEquals(pretrainedTask.getFunctionName(), FunctionName.REMOTE);
479+
verify(mlTaskManager).updateMLTask(anyString(), anyMap(), anyLong(), anyBoolean());
470480
}
471481

472482
public void testIndexRemoteModel() throws PrivilegedActionException {

plugin/src/test/java/org/opensearch/ml/rest/RestMLMemoryCircuitBreakerIT.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ public void testRunWithMemoryCircuitBreaker() throws IOException {
5858
exception.getMessage(),
5959
allOf(
6060
containsString("Memory Circuit Breaker is open, please check your resources!"),
61-
containsString("m_l_limit_exceeded_exception")
61+
containsString("circuit_breaking_exception")
6262
)
6363
);
6464

plugin/src/test/java/org/opensearch/ml/rest/RestMLRemoteInferenceIT.java

+4-2
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ public void testPredictWithAutoDeployAndTTL_RemoteModel() throws IOException, In
199199
Response response = createConnector(completionModelConnectorEntity);
200200
Map responseMap = parseResponseToMap(response);
201201
String connectorId = (String) responseMap.get("connector_id");
202-
response = registerRemoteModelWithTTL("openAI-GPT-3.5 completions", connectorId, 1);
202+
response = registerRemoteModelWithTTLAndSkipHeapMemCheck("openAI-GPT-3.5 completions", connectorId, 1);
203203
responseMap = parseResponseToMap(response);
204204
String modelId = (String) responseMap.get("model_id");
205205
String predictInput = "{\n" + " \"parameters\": {\n" + " \"prompt\": \"Say this is a test\"\n" + " }\n" + "}";
@@ -814,11 +814,13 @@ public static Response registerRemoteModel(String name, String connectorId) thro
814814
.makeRequest(client(), "POST", "/_plugins/_ml/models/_register", null, TestHelper.toHttpEntity(registerModelEntity), null);
815815
}
816816

817-
public static Response registerRemoteModelWithTTL(String name, String connectorId, int ttl) throws IOException {
817+
public static Response registerRemoteModelWithTTLAndSkipHeapMemCheck(String name, String connectorId, int ttl) throws IOException {
818818
String registerModelGroupEntity = "{\n"
819819
+ " \"name\": \"remote_model_group\",\n"
820820
+ " \"description\": \"This is an example description\"\n"
821821
+ "}";
822+
String updateJVMHeapThreshold = "{\"persistent\":{\"plugins.ml_commons.jvm_heap_memory_threshold\":0}}";
823+
TestHelper.makeRequest(client(), "PUT", "/_cluster/settings", null, TestHelper.toHttpEntity(updateJVMHeapThreshold), null);
822824
Response response = TestHelper
823825
.makeRequest(
824826
client(),

plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java

+4-5
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.mockito.Mock;
2929
import org.mockito.MockitoAnnotations;
3030
import org.opensearch.client.Client;
31+
import org.opensearch.cluster.service.ClusterApplierService;
3132
import org.opensearch.cluster.service.ClusterService;
3233
import org.opensearch.common.settings.ClusterSettings;
3334
import org.opensearch.common.settings.Settings;
@@ -48,7 +49,6 @@
4849
import org.opensearch.ml.stats.suppliers.CounterSupplier;
4950
import org.opensearch.test.OpenSearchTestCase;
5051
import org.opensearch.threadpool.ThreadPool;
51-
import org.opensearch.transport.TransportService;
5252

5353
public class MLExecuteTaskRunnerTests extends OpenSearchTestCase {
5454

@@ -70,13 +70,12 @@ public class MLExecuteTaskRunnerTests extends OpenSearchTestCase {
7070
@Mock
7171
MLCircuitBreakerService mlCircuitBreakerService;
7272

73-
@Mock
74-
TransportService transportService;
75-
7673
@Mock
7774
ActionListener<MLExecuteTaskResponse> listener;
7875
@Mock
7976
DiscoveryNodeHelper nodeHelper;
77+
@Mock
78+
ClusterApplierService clusterApplierService;
8079

8180
@Rule
8281
public ExpectedException exceptionRule = ExpectedException.none();
@@ -115,7 +114,7 @@ public void setup() {
115114
ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE,
116115
ML_COMMONS_ENABLE_INHOUSE_PYTHON_MODEL
117116
);
118-
clusterService = spy(new ClusterService(settings, clusterSettings, null));
117+
clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterApplierService));
119118
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
120119

121120
Map<Enum, MLStat<?>> stats = new ConcurrentHashMap<>();

plugin/src/test/java/org/opensearch/ml/task/TaskRunnerTests.java

+2-3
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
import org.opensearch.ml.common.MLTask;
3535
import org.opensearch.ml.common.MLTaskState;
3636
import org.opensearch.ml.common.MLTaskType;
37-
import org.opensearch.ml.common.exception.MLLimitExceededException;
3837
import org.opensearch.ml.common.transport.MLTaskRequest;
3938
import org.opensearch.ml.stats.MLNodeLevelStat;
4039
import org.opensearch.ml.stats.MLStat;
@@ -139,8 +138,8 @@ public void testRun_CircuitBreakerOpen() {
139138
TransportService transportService = mock(TransportService.class);
140139
ActionListener listener = mock(ActionListener.class);
141140
MLTaskRequest request = new MLTaskRequest(false);
142-
expectThrows(MLLimitExceededException.class, () -> mlTaskRunner.run(FunctionName.REMOTE, request, transportService, listener));
141+
mlTaskRunner.run(FunctionName.REMOTE, request, transportService, listener);
143142
Long value = (Long) mlStats.getStat(MLNodeLevelStat.ML_CIRCUIT_BREAKER_TRIGGER_COUNT).getValue();
144-
assertEquals(1L, value.longValue());
143+
assertEquals(0L, value.longValue());
145144
}
146145
}

0 commit comments

Comments
 (0)