Skip to content

Commit a3258c1

Browse files
committed
fix memory CB bugs
Signed-off-by: Xun Zhang <xunzh@amazon.com>
1 parent 99e75aa commit a3258c1

File tree

8 files changed

+66
-9
lines changed

8 files changed

+66
-9
lines changed

plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java

+3
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.opensearch.core.xcontent.ToXContent;
2626
import org.opensearch.ml.common.FunctionName;
2727
import org.opensearch.ml.common.MLModel;
28+
import org.opensearch.ml.common.exception.MLLimitExceededException;
2829
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
2930
import org.opensearch.ml.common.input.MLInput;
3031
import org.opensearch.ml.common.transport.MLTaskResponse;
@@ -177,6 +178,8 @@ public void onResponse(MLModel mlModel) {
177178
);
178179
} else if (e instanceof MLResourceNotFoundException) {
179180
wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.NOT_FOUND));
181+
} else if (e instanceof MLLimitExceededException) {
182+
wrappedListener.onFailure(new OpenSearchStatusException(e.getMessage(), RestStatus.SERVICE_UNAVAILABLE));
180183
} else {
181184
wrappedListener
182185
.onFailure(

plugin/src/main/java/org/opensearch/ml/breaker/MemoryCircuitBreaker.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,6 @@ public Short getThreshold() {
5050

5151
@Override
5252
public boolean isOpen() {
53-
return jvmService.stats().getMem().getHeapUsedPercent() > this.getThreshold();
53+
return getThreshold() < 100 && jvmService.stats().getMem().getHeapUsedPercent() > getThreshold();
5454
}
5555
}

plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java

+7-1
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,13 @@ public void dispatchTask(
143143
if (clusterService.localNode().getId().equals(node.getId())) {
144144
log.debug("Execute ML predict request {} locally on node {}", request.getRequestID(), node.getId());
145145
request.setDispatchTask(false);
146-
executeTask(request, listener);
146+
run(
147+
// This is by design to NOT use mlPredictionTaskRequest.getMlInput().getAlgorithm() here
148+
functionName,
149+
request,
150+
transportService,
151+
listener
152+
);
147153
} else {
148154
log.debug("Execute ML predict request {} remotely on node {}", request.getRequestID(), node.getId());
149155
request.setDispatchTask(false);

plugin/src/test/java/org/opensearch/ml/action/prediction/TransportPredictionTaskActionTests.java

+23
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import org.opensearch.ml.common.dataframe.DataFrameBuilder;
4444
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
4545
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
46+
import org.opensearch.ml.common.exception.MLLimitExceededException;
4647
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
4748
import org.opensearch.ml.common.input.MLInput;
4849
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams;
@@ -235,6 +236,28 @@ public void testPrediction_MLResourceNotFoundException() {
235236
assertEquals("Testing MLResourceNotFoundException", argumentCaptor.getValue().getMessage());
236237
}
237238

239+
public void testPrediction_MLLimitExceededException() {
240+
when(modelCacheHelper.getModelInfo(anyString())).thenReturn(model);
241+
when(model.getAlgorithm()).thenReturn(FunctionName.TEXT_EMBEDDING);
242+
243+
doAnswer(invocation -> {
244+
ActionListener<Boolean> listener = invocation.getArgument(3);
245+
listener.onFailure(new MLLimitExceededException("Memory Circuit Breaker is open, please check your resources!"));
246+
return null;
247+
}).when(modelAccessControlHelper).validateModelGroupAccess(any(), any(), any(), any());
248+
249+
doAnswer(invocation -> {
250+
((ActionListener<MLTaskResponse>) invocation.getArguments()[3]).onResponse(null);
251+
return null;
252+
}).when(mlPredictTaskRunner).run(any(), any(), any(), any());
253+
254+
transportPredictionTaskAction.doExecute(null, mlPredictionTaskRequest, actionListener);
255+
256+
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
257+
verify(actionListener).onFailure(argumentCaptor.capture());
258+
assertEquals("Memory Circuit Breaker is open, please check your resources!", argumentCaptor.getValue().getMessage());
259+
}
260+
238261
public void testValidateInputSchemaSuccess() {
239262
RemoteInferenceInputDataSet remoteInferenceInputDataSet = RemoteInferenceInputDataSet
240263
.builder()

plugin/src/test/java/org/opensearch/ml/breaker/MemoryCircuitBreakerTests.java

+19
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import org.mockito.MockitoAnnotations;
1616
import org.opensearch.cluster.service.ClusterService;
1717
import org.opensearch.common.settings.ClusterSettings;
18+
import org.opensearch.common.settings.Setting;
1819
import org.opensearch.common.settings.Settings;
1920
import org.opensearch.monitor.jvm.JvmService;
2021
import org.opensearch.monitor.jvm.JvmStats;
@@ -84,4 +85,22 @@ public void testIsOpen_UpdatedByClusterSettings_ExceedMemoryThreshold() {
8485
settingsService.applySettings(newSettingsBuilder.build());
8586
Assert.assertFalse(breaker.isOpen());
8687
}
88+
89+
@Test
90+
public void testIsOpen_DisableMemoryCB() {
91+
ClusterSettings settingsService = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
92+
settingsService.registerSetting(ML_COMMONS_JVM_HEAP_MEM_THRESHOLD);
93+
when(clusterService.getClusterSettings()).thenReturn(settingsService);
94+
95+
CircuitBreaker breaker = new MemoryCircuitBreaker(Settings.builder().build(), clusterService, jvmService);
96+
97+
when(mem.getHeapUsedPercent()).thenReturn((short) 90);
98+
Assert.assertTrue(breaker.isOpen());
99+
100+
when(mem.getHeapUsedPercent()).thenReturn((short) 100);
101+
Settings.Builder newSettingsBuilder = Settings.builder();
102+
newSettingsBuilder.put("plugins.ml_commons.jvm_heap_memory_threshold", 100);
103+
settingsService.applySettings(newSettingsBuilder.build());
104+
Assert.assertFalse(breaker.isOpen());
105+
}
87106
}

plugin/src/test/java/org/opensearch/ml/model/MLModelCacheHelperTests.java

+5-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.junit.rules.ExpectedException;
2525
import org.mockito.Mock;
2626
import org.mockito.MockitoAnnotations;
27+
import org.opensearch.cluster.ClusterManagerMetrics;
2728
import org.opensearch.cluster.service.ClusterService;
2829
import org.opensearch.common.settings.ClusterSettings;
2930
import org.opensearch.common.settings.Settings;
@@ -64,13 +65,16 @@ public class MLModelCacheHelperTests extends OpenSearchTestCase {
6465
@Mock
6566
private TokenBucket rateLimiter;
6667

68+
@Mock
69+
ClusterManagerMetrics clusterManagerMetrics;
70+
6771
@Before
6872
public void setup() {
6973
MockitoAnnotations.openMocks(this);
7074
maxMonitoringRequests = 10;
7175
settings = Settings.builder().put(ML_COMMONS_MONITORING_REQUEST_COUNT.getKey(), maxMonitoringRequests).build();
7276
ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_MONITORING_REQUEST_COUNT);
73-
clusterService = spy(new ClusterService(settings, clusterSettings, null));
77+
clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterManagerMetrics));
7478

7579
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
7680
cacheHelper = new MLModelCacheHelper(clusterService, settings);

plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
import org.opensearch.action.update.UpdateRequest;
7777
import org.opensearch.action.update.UpdateResponse;
7878
import org.opensearch.client.Client;
79+
import org.opensearch.cluster.ClusterManagerMetrics;
7980
import org.opensearch.cluster.service.ClusterService;
8081
import org.opensearch.common.settings.ClusterSettings;
8182
import org.opensearch.common.settings.Settings;
@@ -177,7 +178,7 @@ public class MLModelManagerTests extends OpenSearchTestCase {
177178
private ScriptService scriptService;
178179

179180
@Mock
180-
private MLTask pretrainedMLTask;
181+
ClusterManagerMetrics clusterManagerMetrics;
181182

182183
@Before
183184
public void setup() throws URISyntaxException {
@@ -196,7 +197,7 @@ public void setup() throws URISyntaxException {
196197
ML_COMMONS_MONITORING_REQUEST_COUNT,
197198
ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE
198199
);
199-
clusterService = spy(new ClusterService(settings, clusterSettings, null));
200+
clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterManagerMetrics));
200201
xContentRegistry = NamedXContentRegistry.EMPTY;
201202

202203
modelName = "model_name1";

plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java

+5-4
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
import org.mockito.Mock;
2929
import org.mockito.MockitoAnnotations;
3030
import org.opensearch.client.Client;
31+
import org.opensearch.cluster.ClusterManagerMetrics;
3132
import org.opensearch.cluster.service.ClusterService;
3233
import org.opensearch.common.settings.ClusterSettings;
3334
import org.opensearch.common.settings.Settings;
@@ -70,14 +71,14 @@ public class MLExecuteTaskRunnerTests extends OpenSearchTestCase {
7071
@Mock
7172
MLCircuitBreakerService mlCircuitBreakerService;
7273

73-
@Mock
74-
TransportService transportService;
75-
7674
@Mock
7775
ActionListener<MLExecuteTaskResponse> listener;
7876
@Mock
7977
DiscoveryNodeHelper nodeHelper;
8078

79+
@Mock
80+
ClusterManagerMetrics clusterManagerMetrics;
81+
8182
@Rule
8283
public ExpectedException exceptionRule = ExpectedException.none();
8384

@@ -115,7 +116,7 @@ public void setup() {
115116
ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE,
116117
ML_COMMONS_ENABLE_INHOUSE_PYTHON_MODEL
117118
);
118-
clusterService = spy(new ClusterService(settings, clusterSettings, null));
119+
clusterService = spy(new ClusterService(settings, clusterSettings, null, clusterManagerMetrics));
119120
when(clusterService.getClusterSettings()).thenReturn(clusterSettings);
120121

121122
Map<Enum, MLStat<?>> stats = new ConcurrentHashMap<>();

0 commit comments

Comments
 (0)