Skip to content

Commit 70ea17f

Browse files
committed
fix memory CB bugs
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 99e75aa commit 70ea17f

File tree

8 files changed

+64
-10
lines changed

8 files changed

+64
-10
lines changed

plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java

+3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.opensearch.core.xcontent.ToXContent;
2626
import org.opensearch.ml.common.FunctionName;
2727
import org.opensearch.ml.common.MLModel;
28+
import org.opensearch.ml.common.exception.MLLimitExceededException;
2829
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
2930
import org.opensearch.ml.common.input.MLInput;
3031
import org.opensearch.ml.common.transport.MLTaskResponse;
@@ -177,6 +178,8 @@ public void onResponse(MLModel mlModel) {
177178
);
178179
} else if (e instanceof MLResourceNotFoundException) {
179180
wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.NOT_FOUND));
181+
} else if (e instanceof MLLimitExceededException) {
182+
wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.SERVICE_UNAVAILABLE));
180183
} else {
181184
wrappedListener
182185
.onFailure(

plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,6 @@ public Short getThreshold() {
5050

5151
@Override
5252
public boolean isOpen() {
53-
return jvmService.stats().getMem().getHeapUsedPercent() > this.getThreshold();
53+
return getThreshold() < 100 && jvmService.stats().getMem().getHeapUsedPercent() > getThreshold();
5454
}
5555
}

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+7-1
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,13 @@ public void dispatchTask(
143143
if (clusterService.localNode().getId().equals(node.getId())) {
144144
log.debug("Execute ML predict request {} locally on node {}", request.getRequestID(), node.getId());
145145
request.setDispatchTask(false);
146-
executeTask(request, listener);
146+
run(
147+
// This is by design to NOT use mlPredictionTaskRequest.getMlInput().getAlgorithm() here
148+
functionName,
149+
request,
150+
transportService,
151+
listener
152+
);
147153
} else {
148154
log.debug("Execute ML predict request {} remotely on node {}", request.getRequestID(), node.getId());
149155
request.setDispatchTask(false);

plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java

+23
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
4444
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
4545
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
46+
import org.opensearch.ml.common.exception.MLLimitExceededException;
4647
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
4748
import org.opensearch.ml.common.input.MLInput;
4849
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
@@ -235,6 +236,28 @@ public void testPrediction_MLResourceNotFoundException() {
235236
assertEquals("Testing MLResourceNotFoundException", argumentCaptor.getValue().getMessage());
236237
}
237238

239+
public void testPrediction_MLLimitExceededException() {
240+
when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model);
241+
when(model.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING);
242+
243+
doAnswer(invocation -> {
244+
ActionListener<Boolean> listener = invocation.getArgument(3);
245+
listener.onFailure(new MLLimitExceededException("Memory Circuit Breaker is open, please check your resources!"));
246+
return null;
247+
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any());
248+
249+
doAnswer(invocation -> {
250+
((ActionListener<MLTaskResponse>) invocation.getArguments()[3]).onResponse(null);
251+
return null;
252+
}).when(mlPredictTaskRunner).run(any(), any(), any(), any());
253+
254+
transportPredictionTaskAction.doExecute(null, mlPredictionTaskRequest, actionListener);
255+
256+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
257+
verify(actionListener).onFailure(argumentCaptor.capture());
258+
assertEquals("Memory Circuit Breaker is open, please check your resources!", argumentCaptor.getValue().getMessage());
259+
}
260+
238261
public void testValidateInputSchemaSuccess() {
239262
RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet
240263
.builder()

plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java

+18
Original file line numberDiff line numberDiff line change
@@ -84,4 +84,22 @@ public void testIsOpen_UpdatedByClusterSettings_ExceedMemoryThreshold() {
8484
settingsService.applySettings(newSettingsBuilder.build());
8585
Assert.assertFalse(breaker.isOpen());
8686
}
87+
88+
@Test
89+
public void testIsOpen_DisableMemoryCB() {
90+
ClusterSettings settingsService = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
91+
settingsService.registerSetting(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD);
92+
when(clusterService.getClusterSettings()).thenReturn(settingsService);
93+
94+
CircuitBreaker breaker = new MemoryCircuitBreaker(Settings.builder().build(), clusterService, jvmService);
95+
96+
when(mem.getHeapUsedPercent()).thenReturn((short) 90);
97+
Assert.assertTrue(breaker.isOpen());
98+
99+
when(mem.getHeapUsedPercent()).thenReturn((short) 100);
100+
Settings.Builder newSettingsBuilder = Settings.builder();
101+
newSettingsBuilder.put("plugins.ml_commons.jvm_heap_memory_threshold", 100);
102+
settingsService.applySettings(newSettingsBuilder.build());
103+
Assert.assertFalse(breaker.isOpen());
104+
}
87105
}

plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.junit.rules.ExpectedException;
2525
import org.mockito.Mock;
2626
import org.mockito.MockitoAnnotations;
27+
import org.opensearch.cluster.service.ClusterApplierService;
2728
import org.opensearch.cluster.service.ClusterService;
2829
import org.opensearch.common.settings.ClusterSettings;
2930
import org.opensearch.common.settings.Settings;
@@ -64,13 +65,16 @@ public class MLModelCacheHelperTests extends OpenSearchTestCase {
6465
@Mock
6566
private TokenBucket rateLimiter;
6667

68+
@Mock
69+
ClusterApplierService clusterApplierService;
70+
6771
@Before
6872
public void setup() {
6973
MockitoAnnotations.openMocks(this);
7074
maxMonitoringRequests = 10;
7175
settings = Settings.builder().put(ML_COMMONS_MONITORING_REQUEST_COUNT.getKey(), maxMonitoringRequests).build();
7276
ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_MONITORING_REQUEST_COUNT);
73-
clusterService = spy(new ClusterService(settings, clusterSettings, null));
77+
clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterApplierService));
7478

7579
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
7680
cacheHelper = new MLModelCacheHelper(clusterService, settings);

plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
import org.opensearch.action.update.UpdateRequest;
7777
import org.opensearch.action.update.UpdateResponse;
7878
import org.opensearch.client.Client;
79+
import org.opensearch.cluster.service.ClusterApplierService;
7980
import org.opensearch.cluster.service.ClusterService;
8081
import org.opensearch.common.settings.ClusterSettings;
8182
import org.opensearch.common.settings.Settings;
@@ -177,7 +178,7 @@ public class MLModelManagerTests extends OpenSearchTestCase {
177178
private ScriptService scriptService;
178179

179180
@Mock
180-
private MLTask pretrainedMLTask;
181+
ClusterApplierService clusterApplierService;
181182

182183
@Before
183184
public void setup() throws URISyntaxException {
@@ -196,7 +197,7 @@ public void setup() throws URISyntaxException {
196197
ML_COMMONS_MONITORING_REQUEST_COUNT,
197198
ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE
198199
);
199-
clusterService = spy(new ClusterService(settings, clusterSettings, null));
200+
clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterApplierService));
200201
xContentRegistry = NamedXContentRegistry.EMPTY;
201202

202203
modelName = "model_name1";

plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java

+4-5
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.mockito.Mock;
2929
import org.mockito.MockitoAnnotations;
3030
import org.opensearch.client.Client;
31+
import org.opensearch.cluster.service.ClusterApplierService;
3132
import org.opensearch.cluster.service.ClusterService;
3233
import org.opensearch.common.settings.ClusterSettings;
3334
import org.opensearch.common.settings.Settings;
@@ -48,7 +49,6 @@
4849
import org.opensearch.ml.stats.suppliers.CounterSupplier;
4950
import org.opensearch.test.OpenSearchTestCase;
5051
import org.opensearch.threadpool.ThreadPool;
51-
import org.opensearch.transport.TransportService;
5252

5353
public class MLExecuteTaskRunnerTests extends OpenSearchTestCase {
5454

@@ -70,13 +70,12 @@ public class MLExecuteTaskRunnerTests extends OpenSearchTestCase {
7070
@Mock
7171
MLCircuitBreakerService mlCircuitBreakerService;
7272

73-
@Mock
74-
TransportService transportService;
75-
7673
@Mock
7774
ActionListener<MLExecuteTaskResponse> listener;
7875
@Mock
7976
DiscoveryNodeHelper nodeHelper;
77+
@Mock
78+
ClusterApplierService clusterApplierService;
8079

8180
@Rule
8281
public ExpectedException exceptionRule = ExpectedException.none();
@@ -115,7 +114,7 @@ public void setup() {
115114
ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE,
116115
ML_COMMONS_ENABLE_INHOUSE_PYTHON_MODEL
117116
);
118-
clusterService = spy(new ClusterService(settings, clusterSettings, null));
117+
clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterApplierService));
119118
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
120119

121120
Map<Enum, MLStat<?>> stats = new ConcurrentHashMap<>();

0 commit comments

Comments
 (0)